Skip to content

Commit 10e5482

Browse files
committed
[CLN] Moving vertex_overlap code to the module
1 parent b58188d commit 10e5482

File tree

3 files changed

+184
-72
lines changed

3 files changed

+184
-72
lines changed

gempy_engine/API/dual_contouring/multi_scalar_dual_contouring.py

Lines changed: 6 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ...core.data.octree_level import OctreeLevel
2020
from ...core.utils import gempy_profiler_decorator
2121
from ...modules.dual_contouring.dual_contouring_interface import (find_intersection_on_edge, get_triangulation_codes,
22-
get_masked_codes, mask_generation)
22+
get_masked_codes, mask_generation,apply_faults_vertex_overlap)
2323

2424

2525
@gempy_profiler_decorator
@@ -100,7 +100,7 @@ def dual_contouring_multi_scalar(
100100
# endregion
101101

102102
# region Vertex gen and triangulation
103-
foo = []
103+
left_right_per_mesh = []
104104
# Generate meshes for each scalar field
105105
for n_scalar_field in range(data_descriptor.stack_structure.n_stacks):
106106
output: InterpOutput = octree_leaves.outputs_centers[n_scalar_field]
@@ -126,84 +126,22 @@ def dual_contouring_multi_scalar(
126126
)
127127

128128
for m in meshes:
129-
foo.append(m.left_right)
129+
left_right_per_mesh.append(m.left_right)
130130

131131
# TODO: If the order of the meshes does not match the order of scalar_field_at_surface points, reorder them here
132132
if meshes is not None:
133133
all_meshes.extend(meshes)
134134

135135
# endregion
136-
# Check for repeated voxels across stacks
137-
if (options.debug or len(all_left_right_codes) > 1) and False:
138-
voxel_overlaps = find_repeated_voxels_across_stacks(foo)
139-
if voxel_overlaps and options.debug:
140-
print(f"Found voxel overlaps between stacks: {voxel_overlaps}")
141-
_f(all_meshes, 1, 0, voxel_overlaps)
142-
_f(all_meshes, 2, 0, voxel_overlaps)
143-
_f(all_meshes, 3, 0, voxel_overlaps)
144-
_f(all_meshes, 4, 0, voxel_overlaps)
145-
_f(all_meshes, 5, 0, voxel_overlaps)
136+
if (options.debug or len(all_left_right_codes) > 1) and True:
137+
apply_faults_vertex_overlap(all_meshes, data_descriptor, left_right_per_mesh)
146138

147139
return all_meshes
148140

141+
# ... existing code ...
149142

150-
def _f(all_meshes: list[DualContouringMesh], destination: int, origin: int, voxel_overlaps: dict):
151-
key = f"stack_{origin}_vs_stack_{destination}"
152-
all_meshes[destination].vertices[voxel_overlaps[key]["indices_in_stack_j"]] = all_meshes[origin].vertices[voxel_overlaps[key]["indices_in_stack_i"]]
153143

154144

155-
def find_repeated_voxels_across_stacks(all_left_right_codes: List[np.ndarray]) -> dict:
156-
"""
157-
Find repeated voxels using NumPy operations - better for very large arrays.
158-
159-
Args:
160-
all_left_right_codes: List of left_right_codes arrays, one per stack
161-
162-
Returns:
163-
Dictionary with detailed overlap analysis
164-
"""
165-
166-
if not all_left_right_codes:
167-
return {}
168-
169-
# Generate voxel codes for each stack
170-
171-
from gempy_engine.modules.dual_contouring.fancy_triangulation import _StaticTriangulationData
172-
stack_codes = []
173-
for left_right_codes in all_left_right_codes:
174-
if left_right_codes.size > 0:
175-
voxel_codes = (left_right_codes * _StaticTriangulationData.get_pack_directions_into_bits()).sum(axis=1)
176-
stack_codes.append(voxel_codes)
177-
else:
178-
stack_codes.append(np.array([]))
179-
180-
overlaps = {}
181-
182-
# Check each pair of stacks
183-
for i in range(len(stack_codes)):
184-
for j in range(i + 1, len(stack_codes)):
185-
if stack_codes[i].size == 0 or stack_codes[j].size == 0:
186-
continue
187-
188-
# Find common voxel codes using numpy
189-
common_codes = np.intersect1d(stack_codes[i], stack_codes[j])
190-
191-
if len(common_codes) > 0:
192-
# Get indices of common voxels in each stack
193-
indices_i = np.isin(stack_codes[i], common_codes)
194-
indices_j = np.isin(stack_codes[j], common_codes)
195-
196-
overlaps[f"stack_{i}_vs_stack_{j}"] = {
197-
'common_voxel_codes' : common_codes,
198-
'count' : len(common_codes),
199-
'indices_in_stack_i' : np.where(indices_i)[0],
200-
'indices_in_stack_j' : np.where(indices_j)[0],
201-
'common_binary_codes_i': all_left_right_codes[i][indices_i],
202-
'common_binary_codes_j': all_left_right_codes[j][indices_j]
203-
}
204-
205-
return overlaps
206-
207145

208146
def _validate_stack_relations(data_descriptor: InputDataDescriptor, n_scalar_field: int) -> None:
209147
"""
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
from typing import List
2+
3+
import numpy as np
4+
5+
from ...core.data.dual_contouring_mesh import DualContouringMesh
6+
7+
8+
9+
def _apply_fault_relations_to_overlaps(
10+
all_meshes: List[DualContouringMesh],
11+
faults_relations: np.ndarray,
12+
voxel_overlaps: dict,
13+
n_stacks: int
14+
) -> None:
15+
"""
16+
Apply fault relations to voxel overlaps by updating mesh vertices.
17+
18+
Args:
19+
all_meshes: List of dual contouring meshes
20+
faults_relations: Boolean matrix indicating fault relationships between stacks
21+
voxel_overlaps: Dictionary containing overlap information between stacks
22+
n_stacks: Total number of stacks
23+
"""
24+
if faults_relations is None:
25+
return
26+
27+
# Calculate mesh indices offset for each stack
28+
mesh_indices_offset = _calculate_mesh_indices_offset(all_meshes, n_stacks)
29+
30+
# Iterate through fault relations matrix
31+
for origin_stack in range(n_stacks):
32+
for destination_stack in range(n_stacks):
33+
# If there's a fault relation from origin to destination
34+
if faults_relations[origin_stack, destination_stack]:
35+
overlap_key = f"stack_{origin_stack}_vs_stack_{destination_stack}"
36+
37+
# Check if there are actual overlaps between these stacks
38+
if overlap_key in voxel_overlaps:
39+
_apply_vertex_sharing(
40+
all_meshes,
41+
origin_stack,
42+
destination_stack,
43+
voxel_overlaps[overlap_key],
44+
mesh_indices_offset
45+
)
46+
47+
48+
def _calculate_mesh_indices_offset(all_meshes: List[DualContouringMesh], n_stacks: int) -> List[int]:
49+
"""
50+
Calculate the starting mesh index for each stack.
51+
52+
Args:
53+
all_meshes: List of all dual contouring meshes
54+
n_stacks: Total number of stacks
55+
56+
Returns:
57+
List of starting mesh indices for each stack
58+
"""
59+
# For now, assume each stack has one mesh (this may need adjustment based on actual structure)
60+
# This is a simplified approach - you may need to adjust based on how meshes are organized
61+
mesh_indices_offset = list(range(n_stacks))
62+
return mesh_indices_offset
63+
64+
65+
def _apply_vertex_sharing(
66+
all_meshes: List[DualContouringMesh],
67+
origin_stack: int,
68+
destination_stack: int,
69+
overlap_data: dict,
70+
mesh_indices_offset: List[int]
71+
) -> None:
72+
"""
73+
Apply vertex sharing between origin and destination meshes based on overlap data.
74+
75+
Args:
76+
all_meshes: List of dual contouring meshes
77+
origin_stack: Stack index that serves as the source of vertices
78+
destination_stack: Stack index that receives vertices from origin
79+
overlap_data: Dictionary containing indices and overlap information
80+
mesh_indices_offset: Starting mesh index for each stack
81+
"""
82+
origin_mesh_idx = mesh_indices_offset[origin_stack]
83+
destination_mesh_idx = mesh_indices_offset[destination_stack]
84+
85+
# Ensure mesh indices are valid
86+
if (origin_mesh_idx >= len(all_meshes) or
87+
destination_mesh_idx >= len(all_meshes)):
88+
return
89+
90+
# Apply the vertex sharing (same logic as original _f function)
91+
origin_mesh = all_meshes[origin_mesh_idx]
92+
destination_mesh = all_meshes[destination_mesh_idx]
93+
94+
indices_in_origin = overlap_data["indices_in_stack_i"]
95+
indices_in_destination = overlap_data["indices_in_stack_j"]
96+
97+
destination_mesh.vertices[indices_in_destination] = origin_mesh.vertices[indices_in_origin]
98+
99+
100+
def _f(all_meshes: list[DualContouringMesh], destination: int, origin: int, voxel_overlaps: dict):
101+
"""
102+
Legacy function - kept for backward compatibility.
103+
Consider using _apply_fault_relations_to_overlaps for new implementations.
104+
"""
105+
key = f"stack_{origin}_vs_stack_{destination}"
106+
if key in voxel_overlaps:
107+
all_meshes[destination].vertices[voxel_overlaps[key]["indices_in_stack_j"]] = all_meshes[origin].vertices[voxel_overlaps[key]["indices_in_stack_i"]]
108+
109+
# def _f(all_meshes: list[DualContouringMesh], destination: int, origin: int, voxel_overlaps: dict):
110+
# key = f"stack_{origin}_vs_stack_{destination}"
111+
# all_meshes[destination].vertices[voxel_overlaps[key]["indices_in_stack_j"]] = all_meshes[origin].vertices[voxel_overlaps[key]["indices_in_stack_i"]]
112+
113+
114+
def find_repeated_voxels_across_stacks(all_left_right_codes: List[np.ndarray]) -> dict:
115+
"""
116+
Find repeated voxels using NumPy operations - better for very large arrays.
117+
118+
Args:
119+
all_left_right_codes: List of left_right_codes arrays, one per stack
120+
121+
Returns:
122+
Dictionary with detailed overlap analysis
123+
"""
124+
125+
if not all_left_right_codes:
126+
return {}
127+
128+
# Generate voxel codes for each stack
129+
130+
from gempy_engine.modules.dual_contouring.fancy_triangulation import _StaticTriangulationData
131+
stack_codes = []
132+
for left_right_codes in all_left_right_codes:
133+
if left_right_codes.size > 0:
134+
voxel_codes = (left_right_codes * _StaticTriangulationData.get_pack_directions_into_bits()).sum(axis=1)
135+
stack_codes.append(voxel_codes)
136+
else:
137+
stack_codes.append(np.array([]))
138+
139+
overlaps = {}
140+
141+
# Check each pair of stacks
142+
for i in range(len(stack_codes)):
143+
for j in range(i + 1, len(stack_codes)):
144+
if stack_codes[i].size == 0 or stack_codes[j].size == 0:
145+
continue
146+
147+
# Find common voxel codes using numpy
148+
common_codes = np.intersect1d(stack_codes[i], stack_codes[j])
149+
150+
if len(common_codes) > 0:
151+
# Get indices of common voxels in each stack
152+
indices_i = np.isin(stack_codes[i], common_codes)
153+
indices_j = np.isin(stack_codes[j], common_codes)
154+
155+
overlaps[f"stack_{i}_vs_stack_{j}"] = {
156+
'common_voxel_codes' : common_codes,
157+
'count' : len(common_codes),
158+
'indices_in_stack_i' : np.where(indices_i)[0],
159+
'indices_in_stack_j' : np.where(indices_j)[0],
160+
'common_binary_codes_i': all_left_right_codes[i][indices_i],
161+
'common_binary_codes_j': all_left_right_codes[j][indices_j]
162+
}
163+
164+
return overlaps

gempy_engine/modules/dual_contouring/dual_contouring_interface.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33

44
import numpy as np
55

6+
from ._vertex_overlap import find_repeated_voxels_across_stacks, _apply_fault_relations_to_overlaps
67
from .fancy_triangulation import get_left_right_array
78
from ...core.backend_tensor import BackendTensor
89
from ...core.data import InterpolationOptions
10+
from ...core.data.dual_contouring_mesh import DualContouringMesh
911
from ...core.data.input_data_descriptor import InputDataDescriptor
1012
from ...core.data.interp_output import InterpOutput
1113
from ...core.data.octree_level import OctreeLevel
@@ -109,8 +111,7 @@ def get_triangulation_codes(octree_list: List[OctreeLevel], options: Interpolati
109111
raise ValueError("Invalid combination of options")
110112

111113

112-
113-
def get_masked_codes(left_right_codes: np.ndarray | None, mask: np.ndarray | None) -> np.ndarray | None:
114+
def get_masked_codes(left_right_codes: np.ndarray, mask: np.ndarray) -> np.ndarray:
114115
"""
115116
Apply mask to left-right codes if both are available.
116117
@@ -133,7 +134,7 @@ def get_masked_codes(left_right_codes: np.ndarray | None, mask: np.ndarray | Non
133134
def mask_generation(
134135
octree_leaves: OctreeLevel,
135136
masking_option: MeshExtractionMaskingOptions
136-
) -> np.ndarray | None:
137+
) -> np.ndarray:
137138
"""
138139
Generate masks for mesh extraction based on masking options and stack relations.
139140
@@ -196,4 +197,13 @@ def mask_generation(
196197
return mask_matrix
197198

198199

199-
# endregion
200+
# endregion
201+
def apply_faults_vertex_overlap(all_meshes: list[DualContouringMesh],
202+
data_descriptor: InputDataDescriptor,
203+
left_right_per_mesh: list[np.ndarray]):
204+
faults_relations = data_descriptor.stack_structure.faults_relations
205+
voxel_overlaps = find_repeated_voxels_across_stacks(left_right_per_mesh)
206+
207+
if voxel_overlaps:
208+
print(f"Found voxel overlaps between stacks: {voxel_overlaps}")
209+
_apply_fault_relations_to_overlaps(all_meshes, faults_relations, voxel_overlaps, data_descriptor.stack_structure.n_stacks)

0 commit comments

Comments
 (0)