Skip to content

Commit 6c8e3b6

Browse files
committed
[CLN]
1 parent 2eb31e3 commit 6c8e3b6

File tree

1 file changed

+170
-56
lines changed

1 file changed

+170
-56
lines changed
Lines changed: 170 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import copy
22
import warnings
3-
from typing import List
3+
from typing import List, Any
44

55
import numpy as np
66

@@ -24,8 +24,24 @@
2424

2525

2626
@gempy_profiler_decorator
27-
def dual_contouring_multi_scalar(data_descriptor: InputDataDescriptor, interpolation_input: InterpolationInput,
28-
options: InterpolationOptions, octree_list: list[OctreeLevel]) -> List[DualContouringMesh]:
27+
def dual_contouring_multi_scalar(
28+
data_descriptor: InputDataDescriptor,
29+
interpolation_input: InterpolationInput,
30+
options: InterpolationOptions,
31+
octree_list: List[OctreeLevel]
32+
) -> List[DualContouringMesh]:
33+
"""
34+
Perform dual contouring for multiple scalar fields.
35+
36+
Args:
37+
data_descriptor: Input data descriptor containing stack structure information
38+
interpolation_input: Input data for interpolation
39+
options: Interpolation options including debug and extraction settings
40+
octree_list: List of octree levels with the last being the leaf level
41+
42+
Returns:
43+
List of dual contouring meshes for all processed scalar fields
44+
"""
2945
# Dual Contouring prep:
3046
MaskBuffer.clean()
3147

@@ -36,45 +52,30 @@ def dual_contouring_multi_scalar(data_descriptor: InputDataDescriptor, interpola
3652
dual_contouring_options.evaluation_options.compute_scalar_gradient = True
3753

3854
if options.debug_water_tight:
39-
_experimental_water_tight(all_meshes, data_descriptor, interpolation_input, octree_leaves, dual_contouring_options)
55+
_experimental_water_tight(
56+
all_meshes, data_descriptor, interpolation_input, octree_leaves, dual_contouring_options
57+
)
4058
return all_meshes
4159

42-
# region new triangulations
43-
is_pure_octree = bool(np.all(octree_list[0].grid_centers.octree_grid_shape == 2))
44-
match (options.evaluation_options.mesh_extraction_fancy, is_pure_octree):
45-
case (True, True):
46-
left_right_codes = get_left_right_array(octree_list)
47-
case (True, False):
48-
left_right_codes = None
49-
warnings.warn("Fancy triangulation only works with regular grid of resolution [2,2,2]. Defaulting to regular triangulation")
50-
case (False, _):
51-
left_right_codes = None
52-
case _:
53-
raise ValueError("Invalid combination of options")
54-
# endregion
60+
# Determine triangulation strategy
61+
left_right_codes = _get_triangulation_codes(octree_list, options)
5562

63+
# Generate masks for all scalar fields
5664
all_mask_arrays: np.ndarray = _mask_generation(
5765
octree_leaves=octree_leaves,
5866
masking_option=options.evaluation_options.mesh_extraction_masking_options
5967
)
6068

69+
# Process each scalar field
6170
all_stack_intersection = []
6271
all_valid_edges = []
6372
all_left_right_codes = []
6473

6574
for n_scalar_field in range(data_descriptor.stack_structure.n_stacks):
66-
previous_stack_is_onlap = data_descriptor.stack_relation[n_scalar_field - 1] == 'Onlap'
67-
was_erosion_before = data_descriptor.stack_relation[n_scalar_field - 1] == 'Erosion'
68-
if previous_stack_is_onlap and was_erosion_before: # ? (July, 2023) Is this still valid? I thought we have all the combinations
69-
raise NotImplementedError("Erosion and Onlap are not supported yet")
70-
pass
75+
_validate_stack_relations(data_descriptor, n_scalar_field)
7176

7277
mask: np.ndarray = all_mask_arrays[n_scalar_field]
73-
74-
if mask is not None and left_right_codes is not None:
75-
left_right_codes_per_stack = left_right_codes[mask]
76-
else:
77-
left_right_codes_per_stack = left_right_codes
78+
left_right_codes_per_stack = _get_masked_codes(left_right_codes, mask)
7879

7980
output: InterpOutput = octree_leaves.outputs_centers[n_scalar_field]
8081
intersection_xyz, valid_edges = find_intersection_on_edge(
@@ -88,101 +89,214 @@ def dual_contouring_multi_scalar(data_descriptor: InputDataDescriptor, interpola
8889
all_valid_edges.append(valid_edges)
8990
all_left_right_codes.append(left_right_codes_per_stack)
9091

91-
output_on_edges = _interp_on_edges(all_stack_intersection, data_descriptor, dual_contouring_options, interpolation_input)
92+
# Interpolate on edges for all stacks
93+
output_on_edges = _interp_on_edges(
94+
all_stack_intersection, data_descriptor, dual_contouring_options, interpolation_input
95+
)
9296

97+
# Generate meshes for each scalar field
9398
for n_scalar_field in range(data_descriptor.stack_structure.n_stacks):
9499
output: InterpOutput = octree_leaves.outputs_centers[n_scalar_field]
100+
mask = all_mask_arrays[n_scalar_field]
101+
95102
dc_data = DualContouringData(
96103
xyz_on_edge=all_stack_intersection[n_scalar_field],
97104
valid_edges=all_valid_edges[n_scalar_field],
98-
xyz_on_centers=octree_leaves.grid_centers.octree_grid.values if mask is None else octree_leaves.grid_centers.octree_grid.values[mask],
105+
xyz_on_centers=(
106+
octree_leaves.grid_centers.octree_grid.values if mask is None
107+
else octree_leaves.grid_centers.octree_grid.values[mask]
108+
),
99109
dxdydz=octree_leaves.grid_centers.octree_dxdydz,
100110
exported_fields_on_edges=output_on_edges[n_scalar_field].exported_fields,
101111
n_surfaces_to_export=output.scalar_field_at_sp.shape[0],
102112
tree_depth=options.number_octree_levels,
103113
)
114+
104115
meshes: List[DualContouringMesh] = compute_dual_contouring(
105116
dc_data_per_stack=dc_data,
106117
left_right_codes=all_left_right_codes[n_scalar_field],
107118
debug=options.debug
108119
)
109120

110-
# ! If the order of the meshes does not match the order of scalar_field_at_surface points we need to reorder them HERE
111-
121+
# TODO: If the order of the meshes does not match the order of scalar_field_at_surface points, reorder them here
112122
if meshes is not None:
113123
all_meshes.extend(meshes)
114-
# @on
115124

116125
return all_meshes
117126

118127

119-
def _interp_on_edges(all_stack_intersection: list[Any], data_descriptor: InputDataDescriptor, dual_contouring_options: InterpolationOptions, interpolation_input: InterpolationInput) -> list[InterpOutput]:
128+
def _get_triangulation_codes(octree_list: List[OctreeLevel], options: InterpolationOptions) -> np.ndarray | None:
129+
"""
130+
Determine the appropriate triangulation codes based on options and octree structure.
131+
132+
Args:
133+
octree_list: List of octree levels
134+
options: Interpolation options
135+
136+
Returns:
137+
Left-right codes array if fancy triangulation is enabled and supported, None otherwise
138+
"""
139+
is_pure_octree = bool(np.all(octree_list[0].grid_centers.octree_grid_shape == 2))
140+
141+
match (options.evaluation_options.mesh_extraction_fancy, is_pure_octree):
142+
case (True, True):
143+
return get_left_right_array(octree_list)
144+
case (True, False):
145+
warnings.warn(
146+
"Fancy triangulation only works with regular grid of resolution [2,2,2]. "
147+
"Defaulting to regular triangulation"
148+
)
149+
return None
150+
case (False, _):
151+
return None
152+
case _:
153+
raise ValueError("Invalid combination of options")
154+
155+
156+
def _validate_stack_relations(data_descriptor: InputDataDescriptor, n_scalar_field: int) -> None:
157+
"""
158+
Validate stack relations for the given scalar field.
159+
160+
Args:
161+
data_descriptor: Input data descriptor containing stack relations
162+
n_scalar_field: Current scalar field index
163+
164+
Raises:
165+
NotImplementedError: If unsupported combination of Erosion and Onlap is detected
166+
"""
167+
if n_scalar_field == 0:
168+
return
169+
170+
previous_stack_is_onlap = data_descriptor.stack_relation[n_scalar_field - 1] == 'Onlap'
171+
was_erosion_before = data_descriptor.stack_relation[n_scalar_field - 1] == 'Erosion'
172+
173+
if previous_stack_is_onlap and was_erosion_before:
174+
# TODO (July, 2023): Is this still valid? I thought we have all the combinations
175+
raise NotImplementedError("Erosion and Onlap are not supported yet")
176+
177+
178+
def _get_masked_codes(left_right_codes: np.ndarray | None, mask: np.ndarray | None) -> np.ndarray | None:
179+
"""
180+
Apply mask to left-right codes if both are available.
181+
182+
Args:
183+
left_right_codes: Original left-right codes array
184+
mask: Boolean mask array
185+
186+
Returns:
187+
Masked codes if both inputs are not None, otherwise original codes
188+
"""
189+
if mask is not None and left_right_codes is not None:
190+
return left_right_codes[mask]
191+
return left_right_codes
192+
193+
194+
def _interp_on_edges(
195+
all_stack_intersection: List[Any],
196+
data_descriptor: InputDataDescriptor,
197+
dual_contouring_options: InterpolationOptions,
198+
interpolation_input: InterpolationInput
199+
) -> List[InterpOutput]:
200+
"""
201+
Interpolate scalar fields on edge intersection points.
202+
203+
Args:
204+
all_stack_intersection: List of intersection points for all stacks
205+
data_descriptor: Input data descriptor
206+
dual_contouring_options: Dual contouring specific options
207+
interpolation_input: Interpolation input data
208+
209+
Returns:
210+
List of interpolation outputs for each stack
211+
"""
120212
from ...core.data.engine_grid import EngineGrid
121213
from ...core.data.generic_grid import GenericGrid
122214
from ..interp_single.interp_features import interpolate_all_fields_no_octree
215+
216+
# Set temporary grid with concatenated intersection points
123217
interpolation_input.set_temp_grid(
124218
EngineGrid(
125219
custom_grid=GenericGrid(
126220
values=BackendTensor.t.concatenate(all_stack_intersection, axis=0)
127221
)
128222
)
129223
)
130-
# endregion
131224

132-
# ! (@miguel 21 June) I think by definition in the function `interpolate_all_fields_no_octree`
133-
# ! we just need to interpolate up to the n_scalar_field, but I am not sure about this. I need to test it
225+
# TODO (@miguel 21 June): By definition in `interpolate_all_fields_no_octree`
226+
# we just need to interpolate up to the n_scalar_field, but need to test this
227+
# This should be done with buffer weights to avoid waste
134228
output_on_edges: List[InterpOutput] = interpolate_all_fields_no_octree(
135229
interpolation_input=interpolation_input,
136230
options=dual_contouring_options,
137231
data_descriptor=data_descriptor
138-
) # ! This has to be done with buffer weights otherwise is a waste
232+
)
233+
234+
# Restore original grid
139235
interpolation_input.set_grid_to_original()
140236
return output_on_edges
141237

142238

143-
def _mask_generation(octree_leaves, masking_option: MeshExtractionMaskingOptions) -> np.ndarray | None:
144-
all_scalar_fields_outputs: list[InterpOutput] = octree_leaves.outputs_centers
239+
def _mask_generation(
240+
octree_leaves: OctreeLevel,
241+
masking_option: MeshExtractionMaskingOptions
242+
) -> np.ndarray | None:
243+
"""
244+
Generate masks for mesh extraction based on masking options and stack relations.
245+
246+
Args:
247+
octree_leaves: Octree leaf level containing scalar field outputs
248+
masking_option: Mesh extraction masking configuration
249+
250+
Returns:
251+
Matrix of boolean masks for each scalar field
252+
253+
Raises:
254+
NotImplementedError: For unsupported masking options
255+
ValueError: For invalid option combinations
256+
"""
257+
all_scalar_fields_outputs: List[InterpOutput] = octree_leaves.outputs_centers
145258
n_scalar_fields = len(all_scalar_fields_outputs)
146259
outputs_ = all_scalar_fields_outputs[0]
147260
slice_corners = outputs_.grid.corners_grid_slice
148261
grid_size = outputs_.cornersGrid_values.shape[0]
262+
149263
mask_matrix = BackendTensor.t.zeros((n_scalar_fields, grid_size // 8), dtype=bool)
150264
onlap_chain_counter = 0
151265

152266
for i in range(n_scalar_fields):
153267
stack_relation = all_scalar_fields_outputs[i].scalar_fields.stack_relation
268+
154269
match (masking_option, stack_relation):
155270
case MeshExtractionMaskingOptions.RAW, _:
156271
mask_matrix[i] = BackendTensor.t.ones(grid_size // 8, dtype=bool)
272+
157273
case MeshExtractionMaskingOptions.DISJOINT, _:
158-
raise NotImplementedError("Disjoint is not supported yet. Not even sure if there is anything to support")
159-
# case (MeshExtractionMaskingOptions.DISJOINT | MeshExtractionMaskingOptions.INTERSECT, StackRelationType.FAULT):
160-
# mask_matrix[i] = np.ones(grid_size//8, dtype=bool)
161-
# case MeshExtractionMaskingOptions.DISJOINT, StackRelationType.ERODE | StackRelationType.BASEMENT:
162-
# mask_scalar = all_scalar_fields_outputs[i - 1].squeezed_mask_array.reshape((1, -1, 8)).sum(-1, bool)[0]
163-
# if MaskBuffer.previous_mask is None:
164-
# mask = mask_scalar
165-
# else:
166-
# mask = (MaskBuffer.previous_mask ^ mask_scalar) * mask_scalar
167-
# MaskBuffer.previous_mask = mask
168-
# case MeshExtractionMaskingOptions.DISJOINT, StackRelationType.ONLAP:
169-
# raise NotImplementedError("Onlap is not supported yet")
170-
# return octree_leaves.outputs_corners[n_scalar_field].squeezed_mask_array.reshape((1, -1, 8)).sum(-1, bool)[0]
274+
raise NotImplementedError(
275+
"Disjoint is not supported yet. Not even sure if there is anything to support"
276+
)
277+
171278
case MeshExtractionMaskingOptions.INTERSECT, StackRelationType.ERODE:
172-
x = all_scalar_fields_outputs[i + onlap_chain_counter].squeezed_mask_array[slice_corners].reshape((1, -1, 8))
279+
mask_array = all_scalar_fields_outputs[i + onlap_chain_counter].squeezed_mask_array
280+
x = mask_array[slice_corners].reshape((1, -1, 8))
173281
mask_matrix[i] = BackendTensor.t.sum(x, -1, bool)[0]
174282
onlap_chain_counter = 0
283+
175284
case MeshExtractionMaskingOptions.INTERSECT, StackRelationType.BASEMENT:
176-
x = all_scalar_fields_outputs[i].squeezed_mask_array[slice_corners].reshape((1, -1, 8))
285+
mask_array = all_scalar_fields_outputs[i].squeezed_mask_array
286+
x = mask_array[slice_corners].reshape((1, -1, 8))
177287
mask_matrix[i] = BackendTensor.t.sum(x, -1, bool)[0]
178288
onlap_chain_counter = 0
289+
179290
case MeshExtractionMaskingOptions.INTERSECT, StackRelationType.ONLAP:
180-
x = all_scalar_fields_outputs[i].squeezed_mask_array[slice_corners].reshape((1, -1, 8))
291+
mask_array = all_scalar_fields_outputs[i].squeezed_mask_array
292+
x = mask_array[slice_corners].reshape((1, -1, 8))
181293
mask_matrix[i] = BackendTensor.t.sum(x, -1, bool)[0]
182294
onlap_chain_counter += 1
295+
183296
case _, StackRelationType.FAULT:
184297
mask_matrix[i] = BackendTensor.t.ones(grid_size // 8, dtype=bool)
298+
185299
case _:
186300
raise ValueError("Invalid combination of options")
187301

188-
return mask_matrix
302+
return mask_matrix

0 commit comments

Comments
 (0)