Skip to content

Commit 5f6904a

Browse files
authored
[ENH] Implement fault vertex overlap handling in dual contouring (#29)
# Refactor and improve fault vertex overlap handling in dual contouring This PR refactors the vertex overlap handling for fault surfaces in the dual contouring module: - Extracts vertex overlap detection and handling into a dedicated `_vertex_overlap.py` module - Improves the implementation of fault relations application to mesh vertices - Renames variables for better clarity (e.g., `foo` → `left_right_per_mesh`) - Adds proper function to apply fault vertex overlap with improved structure - Enhances code organization with helper functions for better readability - Updates test parameters for more reliable fault model testing The changes improve the handling of fault surfaces in the dual contouring algorithm by ensuring proper vertex sharing between meshes that represent fault surfaces.
2 parents b58188d + 986c7d4 commit 5f6904a

File tree

4 files changed

+198
-73
lines changed

4 files changed

+198
-73
lines changed

gempy_engine/API/dual_contouring/multi_scalar_dual_contouring.py

Lines changed: 5 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ...core.data.octree_level import OctreeLevel
2020
from ...core.utils import gempy_profiler_decorator
2121
from ...modules.dual_contouring.dual_contouring_interface import (find_intersection_on_edge, get_triangulation_codes,
22-
get_masked_codes, mask_generation)
22+
get_masked_codes, mask_generation,apply_faults_vertex_overlap)
2323

2424

2525
@gempy_profiler_decorator
@@ -100,7 +100,7 @@ def dual_contouring_multi_scalar(
100100
# endregion
101101

102102
# region Vertex gen and triangulation
103-
foo = []
103+
left_right_per_mesh = []
104104
# Generate meshes for each scalar field
105105
for n_scalar_field in range(data_descriptor.stack_structure.n_stacks):
106106
output: InterpOutput = octree_leaves.outputs_centers[n_scalar_field]
@@ -126,84 +126,22 @@ def dual_contouring_multi_scalar(
126126
)
127127

128128
for m in meshes:
129-
foo.append(m.left_right)
129+
left_right_per_mesh.append(m.left_right)
130130

131131
# TODO: If the order of the meshes does not match the order of scalar_field_at_surface points, reorder them here
132132
if meshes is not None:
133133
all_meshes.extend(meshes)
134134

135135
# endregion
136-
# Check for repeated voxels across stacks
137136
if (options.debug or len(all_left_right_codes) > 1) and False:
138-
voxel_overlaps = find_repeated_voxels_across_stacks(foo)
139-
if voxel_overlaps and options.debug:
140-
print(f"Found voxel overlaps between stacks: {voxel_overlaps}")
141-
_f(all_meshes, 1, 0, voxel_overlaps)
142-
_f(all_meshes, 2, 0, voxel_overlaps)
143-
_f(all_meshes, 3, 0, voxel_overlaps)
144-
_f(all_meshes, 4, 0, voxel_overlaps)
145-
_f(all_meshes, 5, 0, voxel_overlaps)
137+
apply_faults_vertex_overlap(all_meshes, data_descriptor.stack_structure, left_right_per_mesh)
146138

147139
return all_meshes
148140

141+
# ... existing code ...
149142

150-
def _f(all_meshes: list[DualContouringMesh], destination: int, origin: int, voxel_overlaps: dict):
151-
key = f"stack_{origin}_vs_stack_{destination}"
152-
all_meshes[destination].vertices[voxel_overlaps[key]["indices_in_stack_j"]] = all_meshes[origin].vertices[voxel_overlaps[key]["indices_in_stack_i"]]
153143

154144

155-
def find_repeated_voxels_across_stacks(all_left_right_codes: List[np.ndarray]) -> dict:
156-
"""
157-
Find repeated voxels using NumPy operations - better for very large arrays.
158-
159-
Args:
160-
all_left_right_codes: List of left_right_codes arrays, one per stack
161-
162-
Returns:
163-
Dictionary with detailed overlap analysis
164-
"""
165-
166-
if not all_left_right_codes:
167-
return {}
168-
169-
# Generate voxel codes for each stack
170-
171-
from gempy_engine.modules.dual_contouring.fancy_triangulation import _StaticTriangulationData
172-
stack_codes = []
173-
for left_right_codes in all_left_right_codes:
174-
if left_right_codes.size > 0:
175-
voxel_codes = (left_right_codes * _StaticTriangulationData.get_pack_directions_into_bits()).sum(axis=1)
176-
stack_codes.append(voxel_codes)
177-
else:
178-
stack_codes.append(np.array([]))
179-
180-
overlaps = {}
181-
182-
# Check each pair of stacks
183-
for i in range(len(stack_codes)):
184-
for j in range(i + 1, len(stack_codes)):
185-
if stack_codes[i].size == 0 or stack_codes[j].size == 0:
186-
continue
187-
188-
# Find common voxel codes using numpy
189-
common_codes = np.intersect1d(stack_codes[i], stack_codes[j])
190-
191-
if len(common_codes) > 0:
192-
# Get indices of common voxels in each stack
193-
indices_i = np.isin(stack_codes[i], common_codes)
194-
indices_j = np.isin(stack_codes[j], common_codes)
195-
196-
overlaps[f"stack_{i}_vs_stack_{j}"] = {
197-
'common_voxel_codes' : common_codes,
198-
'count' : len(common_codes),
199-
'indices_in_stack_i' : np.where(indices_i)[0],
200-
'indices_in_stack_j' : np.where(indices_j)[0],
201-
'common_binary_codes_i': all_left_right_codes[i][indices_i],
202-
'common_binary_codes_j': all_left_right_codes[j][indices_j]
203-
}
204-
205-
return overlaps
206-
207145

208146
def _validate_stack_relations(data_descriptor: InputDataDescriptor, n_scalar_field: int) -> None:
209147
"""
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
from typing import List, Dict, Optional
2+
3+
import numpy as np
4+
5+
from ...core.data.dual_contouring_mesh import DualContouringMesh
6+
from ...core.data.stacks_structure import StacksStructure
7+
8+
9+
def _apply_fault_relations_to_overlaps(
10+
all_meshes: List[DualContouringMesh],
11+
voxel_overlaps: Dict[str, dict],
12+
stacks_structure: StacksStructure
13+
) -> None:
14+
"""
15+
Apply fault relations to voxel overlaps by updating mesh vertices.
16+
17+
Args:
18+
all_meshes: List of dual contouring meshes
19+
voxel_overlaps: Dictionary containing overlap information between stacks
20+
stacks_structure: Structure containing fault relations and stack information
21+
"""
22+
if stacks_structure.faults_relations is None:
23+
return
24+
25+
faults_relations = stacks_structure.faults_relations
26+
n_stacks = stacks_structure.n_stacks
27+
surfaces_per_stack = stacks_structure.number_of_surfaces_per_stack_vector
28+
29+
# Process fault relations
30+
for origin_stack, destination_stack in _get_fault_pairs(faults_relations, n_stacks):
31+
surface_range = _get_surface_range(surfaces_per_stack, destination_stack)
32+
33+
for surface_n in surface_range:
34+
overlap_key = f"stack_{origin_stack}_vs_stack_{surface_n}"
35+
36+
if overlap_key in voxel_overlaps:
37+
_apply_vertex_sharing(
38+
all_meshes,
39+
origin_stack,
40+
surface_n,
41+
voxel_overlaps[overlap_key]
42+
)
43+
44+
45+
def _get_fault_pairs(faults_relations: np.ndarray, n_stacks: int):
46+
"""Generate pairs of stacks that have fault relations."""
47+
for origin_stack in range(n_stacks):
48+
for destination_stack in range(n_stacks):
49+
if faults_relations[origin_stack, destination_stack]:
50+
yield origin_stack, destination_stack
51+
52+
53+
def _get_surface_range(surfaces_per_stack: np.ndarray, stack_index: int) -> range:
54+
"""Get the range of surfaces for a given stack."""
55+
return range(
56+
surfaces_per_stack[stack_index],
57+
surfaces_per_stack[stack_index + 1]
58+
)
59+
60+
61+
def _apply_vertex_sharing(
62+
all_meshes: List[DualContouringMesh],
63+
origin_mesh_idx: int,
64+
destination_mesh_idx: int,
65+
overlap_data: dict
66+
) -> None:
67+
"""
68+
Apply vertex sharing between origin and destination meshes based on overlap data.
69+
70+
Args:
71+
all_meshes: List of dual contouring meshes
72+
origin_mesh_idx: Index of mesh that serves as the source of vertices
73+
destination_mesh_idx: Index of mesh that receives vertices from origin
74+
overlap_data: Dictionary containing indices and overlap information
75+
"""
76+
if not _are_valid_mesh_indices(all_meshes, origin_mesh_idx, destination_mesh_idx):
77+
return
78+
79+
origin_mesh = all_meshes[origin_mesh_idx]
80+
destination_mesh = all_meshes[destination_mesh_idx]
81+
82+
# Share vertices from origin to destination
83+
origin_indices = overlap_data["indices_in_stack_i"]
84+
destination_indices = overlap_data["indices_in_stack_j"]
85+
86+
destination_mesh.vertices[destination_indices] = origin_mesh.vertices[origin_indices]
87+
88+
89+
def _are_valid_mesh_indices(all_meshes: List[DualContouringMesh], *indices: int) -> bool:
90+
"""Check if all provided mesh indices are valid."""
91+
return all(0 <= idx < len(all_meshes) for idx in indices)
92+
93+
94+
def find_repeated_voxels_across_stacks(all_left_right_codes: List[np.ndarray]) -> Dict[str, dict]:
95+
"""
96+
Find repeated voxels using NumPy operations for efficient processing of large arrays.
97+
98+
Args:
99+
all_left_right_codes: List of left_right_codes arrays, one per stack
100+
101+
Returns:
102+
Dictionary with detailed overlap analysis between stack pairs
103+
"""
104+
if not all_left_right_codes:
105+
return {}
106+
107+
stack_codes = _generate_voxel_codes(all_left_right_codes)
108+
return _find_overlaps_between_stacks(stack_codes, all_left_right_codes)
109+
110+
111+
def _generate_voxel_codes(all_left_right_codes: List[np.ndarray]) -> List[np.ndarray]:
112+
"""Generate voxel codes for each stack using packed bit directions."""
113+
from gempy_engine.modules.dual_contouring.fancy_triangulation import _StaticTriangulationData
114+
115+
pack_directions = _StaticTriangulationData.get_pack_directions_into_bits()
116+
stack_codes = []
117+
118+
for left_right_codes in all_left_right_codes:
119+
if left_right_codes.size > 0:
120+
voxel_codes = (left_right_codes * pack_directions).sum(axis=1)
121+
stack_codes.append(voxel_codes)
122+
else:
123+
stack_codes.append(np.array([]))
124+
125+
return stack_codes
126+
127+
128+
def _find_overlaps_between_stacks(
129+
stack_codes: List[np.ndarray],
130+
all_left_right_codes: List[np.ndarray]
131+
) -> Dict[str, dict]:
132+
"""Find overlaps between all pairs of stacks."""
133+
overlaps = {}
134+
135+
for i in range(len(stack_codes)):
136+
for j in range(i + 1, len(stack_codes)):
137+
overlap_data = _process_stack_pair(
138+
stack_codes[i], stack_codes[j],
139+
all_left_right_codes[i], all_left_right_codes[j],
140+
i, j
141+
)
142+
143+
if overlap_data:
144+
overlaps[f"stack_{i}_vs_stack_{j}"] = overlap_data
145+
146+
return overlaps
147+
148+
149+
def _process_stack_pair(
150+
codes_i: np.ndarray,
151+
codes_j: np.ndarray,
152+
left_right_i: np.ndarray,
153+
left_right_j: np.ndarray,
154+
stack_i: int,
155+
stack_j: int
156+
) -> Optional[dict]:
157+
"""Process a pair of stacks to find overlapping voxels."""
158+
if codes_i.size == 0 or codes_j.size == 0:
159+
return None
160+
161+
common_codes = np.intersect1d(codes_i, codes_j)
162+
163+
if len(common_codes) == 0:
164+
return None
165+
166+
# Find indices of common voxels in each stack
167+
indices_i = np.isin(codes_i, common_codes)
168+
indices_j = np.isin(codes_j, common_codes)
169+
170+
return {
171+
'common_voxel_codes': common_codes,
172+
'count': len(common_codes),
173+
'indices_in_stack_i': np.where(indices_i)[0],
174+
'indices_in_stack_j': np.where(indices_j)[0],
175+
'common_binary_codes_i': left_right_i[indices_i],
176+
'common_binary_codes_j': left_right_j[indices_j]
177+
}

gempy_engine/modules/dual_contouring/dual_contouring_interface.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33

44
import numpy as np
55

6+
from ._vertex_overlap import find_repeated_voxels_across_stacks, _apply_fault_relations_to_overlaps
67
from .fancy_triangulation import get_left_right_array
78
from ...core.backend_tensor import BackendTensor
89
from ...core.data import InterpolationOptions
10+
from ...core.data.dual_contouring_mesh import DualContouringMesh
911
from ...core.data.input_data_descriptor import InputDataDescriptor
1012
from ...core.data.interp_output import InterpOutput
1113
from ...core.data.octree_level import OctreeLevel
1214
from ...core.data.options import MeshExtractionMaskingOptions
1315
from ...core.data.stack_relation_type import StackRelationType
16+
from ...core.data.stacks_structure import StacksStructure
1417

1518

1619
# region edges
@@ -109,8 +112,7 @@ def get_triangulation_codes(octree_list: List[OctreeLevel], options: Interpolati
109112
raise ValueError("Invalid combination of options")
110113

111114

112-
113-
def get_masked_codes(left_right_codes: np.ndarray | None, mask: np.ndarray | None) -> np.ndarray | None:
115+
def get_masked_codes(left_right_codes: np.ndarray, mask: np.ndarray) -> np.ndarray:
114116
"""
115117
Apply mask to left-right codes if both are available.
116118
@@ -133,7 +135,7 @@ def get_masked_codes(left_right_codes: np.ndarray | None, mask: np.ndarray | Non
133135
def mask_generation(
134136
octree_leaves: OctreeLevel,
135137
masking_option: MeshExtractionMaskingOptions
136-
) -> np.ndarray | None:
138+
) -> np.ndarray:
137139
"""
138140
Generate masks for mesh extraction based on masking options and stack relations.
139141
@@ -196,4 +198,12 @@ def mask_generation(
196198
return mask_matrix
197199

198200

199-
# endregion
201+
# endregion
202+
def apply_faults_vertex_overlap(all_meshes: list[DualContouringMesh],
203+
stack_structure: StacksStructure,
204+
left_right_per_mesh: list[np.ndarray]):
205+
voxel_overlaps = find_repeated_voxels_across_stacks(left_right_per_mesh)
206+
207+
if voxel_overlaps:
208+
print(f"Found voxel overlaps between stacks: {voxel_overlaps}")
209+
_apply_fault_relations_to_overlaps(all_meshes, voxel_overlaps, stack_structure)

tests/test_common/test_modules/test_dual_II.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99
from tests.conftest import plot_pyvista
1010

1111

12-
def test_dual_contouring_on_fault_model(one_fault_model, n_oct_levels=4):
12+
def test_dual_contouring_on_fault_model(one_fault_model, n_oct_levels=5):
1313
interpolation_input: InterpolationInput
1414
structure: InputDataDescriptor
1515
options: InterpolationOptions
1616

1717
interpolation_input, structure, options = one_fault_model
1818

1919
import numpy as np
20-
interpolation_input.surface_points.sp_coords[:, 2] += np.random.uniform(-0.1, 0.1, interpolation_input.surface_points.sp_coords[:, 2].shape)
20+
interpolation_input.surface_points.sp_coords[:, 2] += np.random.uniform(-0.02, 0.02, interpolation_input.surface_points.sp_coords[:, 2].shape)
2121
options.compute_scalar_gradient = False
2222
options.evaluation_options.dual_contouring = True
2323
options.evaluation_options.mesh_extraction_masking_options = MeshExtractionMaskingOptions.INTERSECT

0 commit comments

Comments
 (0)