Skip to content

Commit bcac51d

Browse files
authored
[ENH] Refactor dual contouring code with improved documentation and modular functions (#25)
# Refactor dual contouring implementation for improved clarity and maintainability This PR refactors the `dual_contouring_multi_scalar` function to improve code organization, readability, and maintainability. The changes include: - Added comprehensive docstrings to explain function purpose and parameters - Extracted helper functions to break down complex logic into smaller, focused components: - `_get_triangulation_codes`: Determines appropriate triangulation strategy - `_validate_stack_relations`: Validates stack relation configurations - `_get_masked_codes`: Applies masks to triangulation codes - `_interp_on_edges`: Handles interpolation on edge intersection points - Enhanced `_mask_generation` with better documentation The refactoring preserves all existing functionality while making the code more modular and easier to understand. Edge intersection detection and processing are now more clearly separated, improving the overall architecture.
2 parents 4e2833d + 6c8e3b6 commit bcac51d

File tree

1 file changed

+216
-58
lines changed

1 file changed

+216
-58
lines changed
Lines changed: 216 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import copy
22
import warnings
3-
from typing import List
3+
from typing import List, Any
44

55
import numpy as np
66

@@ -19,12 +19,29 @@
1919
from ...core.data.options import MeshExtractionMaskingOptions
2020
from ...core.data.stack_relation_type import StackRelationType
2121
from ...core.utils import gempy_profiler_decorator
22+
from ...modules.dual_contouring.dual_contouring_interface import find_intersection_on_edge
2223
from ...modules.dual_contouring.fancy_triangulation import get_left_right_array
2324

2425

2526
@gempy_profiler_decorator
26-
def dual_contouring_multi_scalar(data_descriptor: InputDataDescriptor, interpolation_input: InterpolationInput,
27-
options: InterpolationOptions, octree_list: list[OctreeLevel]) -> List[DualContouringMesh]:
27+
def dual_contouring_multi_scalar(
28+
data_descriptor: InputDataDescriptor,
29+
interpolation_input: InterpolationInput,
30+
options: InterpolationOptions,
31+
octree_list: List[OctreeLevel]
32+
) -> List[DualContouringMesh]:
33+
"""
34+
Perform dual contouring for multiple scalar fields.
35+
36+
Args:
37+
data_descriptor: Input data descriptor containing stack structure information
38+
interpolation_input: Input data for interpolation
39+
options: Interpolation options including debug and extraction settings
40+
octree_list: List of octree levels with the last being the leaf level
41+
42+
Returns:
43+
List of dual contouring meshes for all processed scalar fields
44+
"""
2845
# Dual Contouring prep:
2946
MaskBuffer.clean()
3047

@@ -35,110 +52,251 @@ def dual_contouring_multi_scalar(data_descriptor: InputDataDescriptor, interpola
3552
dual_contouring_options.evaluation_options.compute_scalar_gradient = True
3653

3754
if options.debug_water_tight:
38-
_experimental_water_tight(all_meshes, data_descriptor, interpolation_input, octree_leaves, dual_contouring_options)
55+
_experimental_water_tight(
56+
all_meshes, data_descriptor, interpolation_input, octree_leaves, dual_contouring_options
57+
)
3958
return all_meshes
4059

41-
# region new triangulations
42-
is_pure_octree = bool(np.all(octree_list[0].grid_centers.octree_grid_shape == 2))
43-
match (options.evaluation_options.mesh_extraction_fancy, is_pure_octree):
44-
case (True, True):
45-
left_right_codes = get_left_right_array(octree_list)
46-
case (True, False):
47-
left_right_codes = None
48-
warnings.warn("Fancy triangulation only works with regular grid of resolution [2,2,2]. Defaulting to regular triangulation")
49-
case (False, _):
50-
left_right_codes = None
51-
case _:
52-
raise ValueError("Invalid combination of options")
53-
# endregion
60+
# Determine triangulation strategy
61+
left_right_codes = _get_triangulation_codes(octree_list, options)
5462

63+
# Generate masks for all scalar fields
5564
all_mask_arrays: np.ndarray = _mask_generation(
5665
octree_leaves=octree_leaves,
5766
masking_option=options.evaluation_options.mesh_extraction_masking_options
5867
)
5968

69+
# Process each scalar field
70+
all_stack_intersection = []
71+
all_valid_edges = []
72+
all_left_right_codes = []
73+
6074
for n_scalar_field in range(data_descriptor.stack_structure.n_stacks):
61-
previous_stack_is_onlap = data_descriptor.stack_relation[n_scalar_field - 1] == 'Onlap'
62-
was_erosion_before = data_descriptor.stack_relation[n_scalar_field - 1] == 'Erosion'
63-
if previous_stack_is_onlap and was_erosion_before: # ? (July, 2023) Is this still valid? I thought we have all the combinations
64-
raise NotImplementedError("Erosion and Onlap are not supported yet")
65-
pass
75+
_validate_stack_relations(data_descriptor, n_scalar_field)
6676

6777
mask: np.ndarray = all_mask_arrays[n_scalar_field]
78+
left_right_codes_per_stack = _get_masked_codes(left_right_codes, mask)
6879

69-
if mask is not None and left_right_codes is not None:
70-
left_right_codes_per_stack = left_right_codes[mask]
71-
else:
72-
left_right_codes_per_stack = left_right_codes
73-
74-
# @off
75-
dc_data: DualContouringData = interpolate_on_edges_for_dual_contouring(
76-
data_descriptor=data_descriptor,
77-
interpolation_input=interpolation_input,
78-
options=dual_contouring_options,
79-
n_scalar_field=n_scalar_field,
80-
octree_leaves=octree_leaves,
81-
mask=mask
80+
output: InterpOutput = octree_leaves.outputs_centers[n_scalar_field]
81+
intersection_xyz, valid_edges = find_intersection_on_edge(
82+
_xyz_corners=octree_leaves.grid_centers.corners_grid.values,
83+
scalar_field_on_corners=output.exported_fields.scalar_field[output.grid.corners_grid_slice],
84+
scalar_at_sp=output.scalar_field_at_sp,
85+
masking=mask
8286
)
8387

88+
all_stack_intersection.append(intersection_xyz)
89+
all_valid_edges.append(valid_edges)
90+
all_left_right_codes.append(left_right_codes_per_stack)
91+
92+
# Interpolate on edges for all stacks
93+
output_on_edges = _interp_on_edges(
94+
all_stack_intersection, data_descriptor, dual_contouring_options, interpolation_input
95+
)
96+
97+
# Generate meshes for each scalar field
98+
for n_scalar_field in range(data_descriptor.stack_structure.n_stacks):
99+
output: InterpOutput = octree_leaves.outputs_centers[n_scalar_field]
100+
mask = all_mask_arrays[n_scalar_field]
101+
102+
dc_data = DualContouringData(
103+
xyz_on_edge=all_stack_intersection[n_scalar_field],
104+
valid_edges=all_valid_edges[n_scalar_field],
105+
xyz_on_centers=(
106+
octree_leaves.grid_centers.octree_grid.values if mask is None
107+
else octree_leaves.grid_centers.octree_grid.values[mask]
108+
),
109+
dxdydz=octree_leaves.grid_centers.octree_dxdydz,
110+
exported_fields_on_edges=output_on_edges[n_scalar_field].exported_fields,
111+
n_surfaces_to_export=output.scalar_field_at_sp.shape[0],
112+
tree_depth=options.number_octree_levels,
113+
)
114+
84115
meshes: List[DualContouringMesh] = compute_dual_contouring(
85116
dc_data_per_stack=dc_data,
86-
left_right_codes=left_right_codes_per_stack,
117+
left_right_codes=all_left_right_codes[n_scalar_field],
87118
debug=options.debug
88119
)
89120

90-
# ! If the order of the meshes does not match the order of scalar_field_at_surface points we need to reorder them HERE
91-
121+
# TODO: If the order of the meshes does not match the order of scalar_field_at_surface points, reorder them here
92122
if meshes is not None:
93123
all_meshes.extend(meshes)
94-
# @on
95124

96125
return all_meshes
97126

98127

99-
def _mask_generation(octree_leaves, masking_option: MeshExtractionMaskingOptions) -> np.ndarray | None:
100-
all_scalar_fields_outputs: list[InterpOutput] = octree_leaves.outputs_centers
128+
def _get_triangulation_codes(octree_list: List[OctreeLevel], options: InterpolationOptions) -> np.ndarray | None:
129+
"""
130+
Determine the appropriate triangulation codes based on options and octree structure.
131+
132+
Args:
133+
octree_list: List of octree levels
134+
options: Interpolation options
135+
136+
Returns:
137+
Left-right codes array if fancy triangulation is enabled and supported, None otherwise
138+
"""
139+
is_pure_octree = bool(np.all(octree_list[0].grid_centers.octree_grid_shape == 2))
140+
141+
match (options.evaluation_options.mesh_extraction_fancy, is_pure_octree):
142+
case (True, True):
143+
return get_left_right_array(octree_list)
144+
case (True, False):
145+
warnings.warn(
146+
"Fancy triangulation only works with regular grid of resolution [2,2,2]. "
147+
"Defaulting to regular triangulation"
148+
)
149+
return None
150+
case (False, _):
151+
return None
152+
case _:
153+
raise ValueError("Invalid combination of options")
154+
155+
156+
def _validate_stack_relations(data_descriptor: InputDataDescriptor, n_scalar_field: int) -> None:
157+
"""
158+
Validate stack relations for the given scalar field.
159+
160+
Args:
161+
data_descriptor: Input data descriptor containing stack relations
162+
n_scalar_field: Current scalar field index
163+
164+
Raises:
165+
NotImplementedError: If unsupported combination of Erosion and Onlap is detected
166+
"""
167+
if n_scalar_field == 0:
168+
return
169+
170+
previous_stack_is_onlap = data_descriptor.stack_relation[n_scalar_field - 1] == 'Onlap'
171+
was_erosion_before = data_descriptor.stack_relation[n_scalar_field - 1] == 'Erosion'
172+
173+
if previous_stack_is_onlap and was_erosion_before:
174+
# TODO (July, 2023): Is this still valid? I thought we have all the combinations
175+
raise NotImplementedError("Erosion and Onlap are not supported yet")
176+
177+
178+
def _get_masked_codes(left_right_codes: np.ndarray | None, mask: np.ndarray | None) -> np.ndarray | None:
179+
"""
180+
Apply mask to left-right codes if both are available.
181+
182+
Args:
183+
left_right_codes: Original left-right codes array
184+
mask: Boolean mask array
185+
186+
Returns:
187+
Masked codes if both inputs are not None, otherwise original codes
188+
"""
189+
if mask is not None and left_right_codes is not None:
190+
return left_right_codes[mask]
191+
return left_right_codes
192+
193+
194+
def _interp_on_edges(
195+
all_stack_intersection: List[Any],
196+
data_descriptor: InputDataDescriptor,
197+
dual_contouring_options: InterpolationOptions,
198+
interpolation_input: InterpolationInput
199+
) -> List[InterpOutput]:
200+
"""
201+
Interpolate scalar fields on edge intersection points.
202+
203+
Args:
204+
all_stack_intersection: List of intersection points for all stacks
205+
data_descriptor: Input data descriptor
206+
dual_contouring_options: Dual contouring specific options
207+
interpolation_input: Interpolation input data
208+
209+
Returns:
210+
List of interpolation outputs for each stack
211+
"""
212+
from ...core.data.engine_grid import EngineGrid
213+
from ...core.data.generic_grid import GenericGrid
214+
from ..interp_single.interp_features import interpolate_all_fields_no_octree
215+
216+
# Set temporary grid with concatenated intersection points
217+
interpolation_input.set_temp_grid(
218+
EngineGrid(
219+
custom_grid=GenericGrid(
220+
values=BackendTensor.t.concatenate(all_stack_intersection, axis=0)
221+
)
222+
)
223+
)
224+
225+
# TODO (@miguel 21 June): By definition in `interpolate_all_fields_no_octree`
226+
# we just need to interpolate up to the n_scalar_field, but need to test this
227+
# This should be done with buffer weights to avoid waste
228+
output_on_edges: List[InterpOutput] = interpolate_all_fields_no_octree(
229+
interpolation_input=interpolation_input,
230+
options=dual_contouring_options,
231+
data_descriptor=data_descriptor
232+
)
233+
234+
# Restore original grid
235+
interpolation_input.set_grid_to_original()
236+
return output_on_edges
237+
238+
239+
def _mask_generation(
240+
octree_leaves: OctreeLevel,
241+
masking_option: MeshExtractionMaskingOptions
242+
) -> np.ndarray | None:
243+
"""
244+
Generate masks for mesh extraction based on masking options and stack relations.
245+
246+
Args:
247+
octree_leaves: Octree leaf level containing scalar field outputs
248+
masking_option: Mesh extraction masking configuration
249+
250+
Returns:
251+
Matrix of boolean masks for each scalar field
252+
253+
Raises:
254+
NotImplementedError: For unsupported masking options
255+
ValueError: For invalid option combinations
256+
"""
257+
all_scalar_fields_outputs: List[InterpOutput] = octree_leaves.outputs_centers
101258
n_scalar_fields = len(all_scalar_fields_outputs)
102259
outputs_ = all_scalar_fields_outputs[0]
103260
slice_corners = outputs_.grid.corners_grid_slice
104261
grid_size = outputs_.cornersGrid_values.shape[0]
262+
105263
mask_matrix = BackendTensor.t.zeros((n_scalar_fields, grid_size // 8), dtype=bool)
106264
onlap_chain_counter = 0
107265

108266
for i in range(n_scalar_fields):
109267
stack_relation = all_scalar_fields_outputs[i].scalar_fields.stack_relation
268+
110269
match (masking_option, stack_relation):
111270
case MeshExtractionMaskingOptions.RAW, _:
112271
mask_matrix[i] = BackendTensor.t.ones(grid_size // 8, dtype=bool)
272+
113273
case MeshExtractionMaskingOptions.DISJOINT, _:
114-
raise NotImplementedError("Disjoint is not supported yet. Not even sure if there is anything to support")
115-
# case (MeshExtractionMaskingOptions.DISJOINT | MeshExtractionMaskingOptions.INTERSECT, StackRelationType.FAULT):
116-
# mask_matrix[i] = np.ones(grid_size//8, dtype=bool)
117-
# case MeshExtractionMaskingOptions.DISJOINT, StackRelationType.ERODE | StackRelationType.BASEMENT:
118-
# mask_scalar = all_scalar_fields_outputs[i - 1].squeezed_mask_array.reshape((1, -1, 8)).sum(-1, bool)[0]
119-
# if MaskBuffer.previous_mask is None:
120-
# mask = mask_scalar
121-
# else:
122-
# mask = (MaskBuffer.previous_mask ^ mask_scalar) * mask_scalar
123-
# MaskBuffer.previous_mask = mask
124-
# case MeshExtractionMaskingOptions.DISJOINT, StackRelationType.ONLAP:
125-
# raise NotImplementedError("Onlap is not supported yet")
126-
# return octree_leaves.outputs_corners[n_scalar_field].squeezed_mask_array.reshape((1, -1, 8)).sum(-1, bool)[0]
274+
raise NotImplementedError(
275+
"Disjoint is not supported yet. Not even sure if there is anything to support"
276+
)
277+
127278
case MeshExtractionMaskingOptions.INTERSECT, StackRelationType.ERODE:
128-
x = all_scalar_fields_outputs[i + onlap_chain_counter].squeezed_mask_array[slice_corners].reshape((1, -1, 8))
279+
mask_array = all_scalar_fields_outputs[i + onlap_chain_counter].squeezed_mask_array
280+
x = mask_array[slice_corners].reshape((1, -1, 8))
129281
mask_matrix[i] = BackendTensor.t.sum(x, -1, bool)[0]
130282
onlap_chain_counter = 0
283+
131284
case MeshExtractionMaskingOptions.INTERSECT, StackRelationType.BASEMENT:
132-
x = all_scalar_fields_outputs[i].squeezed_mask_array[slice_corners].reshape((1, -1, 8))
285+
mask_array = all_scalar_fields_outputs[i].squeezed_mask_array
286+
x = mask_array[slice_corners].reshape((1, -1, 8))
133287
mask_matrix[i] = BackendTensor.t.sum(x, -1, bool)[0]
134288
onlap_chain_counter = 0
289+
135290
case MeshExtractionMaskingOptions.INTERSECT, StackRelationType.ONLAP:
136-
x = all_scalar_fields_outputs[i].squeezed_mask_array[slice_corners].reshape((1, -1, 8))
291+
mask_array = all_scalar_fields_outputs[i].squeezed_mask_array
292+
x = mask_array[slice_corners].reshape((1, -1, 8))
137293
mask_matrix[i] = BackendTensor.t.sum(x, -1, bool)[0]
138294
onlap_chain_counter += 1
295+
139296
case _, StackRelationType.FAULT:
140297
mask_matrix[i] = BackendTensor.t.ones(grid_size // 8, dtype=bool)
298+
141299
case _:
142300
raise ValueError("Invalid combination of options")
143301

144-
return mask_matrix
302+
return mask_matrix

0 commit comments

Comments
 (0)