Skip to content

Commit 4e2833d

Browse files
authored
[REFACTOR] Use center grid for dual contouring instead of corners grid (#24)
# Combining centers and corners This PR refactors the dual contouring implementation to use a more efficient approach for handling grid corners: - Adds `corners_grid` property to `EngineGrid` to store corner values alongside centers - Modifies `interpolate_on_octree` to compute corners as part of the center grid rather than separately - Adds corner-related properties to `InterpOutput` for accessing corner data - Fixes `__len__` in `RegularGrid` to return the correct number of active cells - Replaces `_get_intersection_on_edges` with direct calls to `find_intersection_on_edge` for improved consistency - Updates masking logic to work with the new grid structure - Simplifies octree generation by removing redundant corner-related code These changes improve the efficiency of the dual contouring algorithm by reducing redundant computations and simplifying the data flow.
2 parents d4419c3 + e632517 commit 4e2833d

File tree

11 files changed

+219
-252
lines changed

11 files changed

+219
-252
lines changed

gempy_engine/API/dual_contouring/_interpolate_on_edges.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Tuple, Optional
1+
from typing import List, Optional
22

33
import numpy as np
44

@@ -24,10 +24,15 @@ def interpolate_on_edges_for_dual_contouring(
2424
octree_leaves: OctreeLevel,
2525
mask: Optional[np.ndarray] = None
2626
) -> DualContouringData:
27-
28-
# region define location where we need to interpolate the gradients for dual contouring
29-
output_corners: InterpOutput = octree_leaves.outputs_corners[n_scalar_field]
30-
intersection_xyz, valid_edges = _get_intersection_on_edges(octree_leaves, output_corners, mask)
27+
28+
output: InterpOutput = octree_leaves.outputs_centers[n_scalar_field]
29+
intersection_xyz, valid_edges = find_intersection_on_edge(
30+
_xyz_corners=octree_leaves.grid_centers.corners_grid.values,
31+
scalar_field_on_corners=output.exported_fields.scalar_field[output.grid.corners_grid_slice],
32+
scalar_at_sp=output.scalar_field_at_sp,
33+
masking=mask
34+
)
35+
3136
interpolation_input.set_temp_grid(EngineGrid(custom_grid=GenericGrid(values=intersection_xyz)))
3237
# endregion
3338

@@ -43,22 +48,7 @@ def interpolate_on_edges_for_dual_contouring(
4348
xyz_on_centers=octree_leaves.grid_centers.octree_grid.values if mask is None else octree_leaves.grid_centers.octree_grid.values[mask],
4449
dxdydz=octree_leaves.grid_centers.octree_dxdydz,
4550
exported_fields_on_edges=output_on_edges[n_scalar_field].exported_fields,
46-
n_surfaces_to_export=output_corners.scalar_field_at_sp.shape[0],
51+
n_surfaces_to_export=output.scalar_field_at_sp.shape[0],
4752
tree_depth=options.number_octree_levels,
4853
)
4954
return dc_data
50-
51-
52-
# TODO: These two functions could be moved to the module
53-
def _get_intersection_on_edges(octree_level: OctreeLevel, output_corners: InterpOutput,
54-
mask: Optional[np.ndarray] = None) -> Tuple[np.ndarray, np.ndarray]:
55-
# First find xyz on edges:
56-
intersection_xyz, valid_edges = find_intersection_on_edge(
57-
_xyz_corners=octree_level.grid_corners.values,
58-
scalar_field_on_corners=output_corners.exported_fields.scalar_field,
59-
scalar_at_sp=output_corners.scalar_field_at_sp,
60-
masking=mask
61-
)
62-
return intersection_xyz, valid_edges
63-
64-

gempy_engine/API/dual_contouring/multi_scalar_dual_contouring.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,11 @@ def dual_contouring_multi_scalar(data_descriptor: InputDataDescriptor, interpola
9797

9898

9999
def _mask_generation(octree_leaves, masking_option: MeshExtractionMaskingOptions) -> np.ndarray | None:
100-
all_scalar_fields_outputs: list[InterpOutput] = octree_leaves.outputs_corners
100+
all_scalar_fields_outputs: list[InterpOutput] = octree_leaves.outputs_centers
101101
n_scalar_fields = len(all_scalar_fields_outputs)
102-
grid_size = all_scalar_fields_outputs[0].grid_size
102+
outputs_ = all_scalar_fields_outputs[0]
103+
slice_corners = outputs_.grid.corners_grid_slice
104+
grid_size = outputs_.cornersGrid_values.shape[0]
103105
mask_matrix = BackendTensor.t.zeros((n_scalar_fields, grid_size // 8), dtype=bool)
104106
onlap_chain_counter = 0
105107

@@ -123,15 +125,15 @@ def _mask_generation(octree_leaves, masking_option: MeshExtractionMaskingOptions
123125
# raise NotImplementedError("Onlap is not supported yet")
124126
# return octree_leaves.outputs_corners[n_scalar_field].squeezed_mask_array.reshape((1, -1, 8)).sum(-1, bool)[0]
125127
case MeshExtractionMaskingOptions.INTERSECT, StackRelationType.ERODE:
126-
x = all_scalar_fields_outputs[i + onlap_chain_counter].squeezed_mask_array.reshape((1, -1, 8))
128+
x = all_scalar_fields_outputs[i + onlap_chain_counter].squeezed_mask_array[slice_corners].reshape((1, -1, 8))
127129
mask_matrix[i] = BackendTensor.t.sum(x, -1, bool)[0]
128130
onlap_chain_counter = 0
129131
case MeshExtractionMaskingOptions.INTERSECT, StackRelationType.BASEMENT:
130-
x = all_scalar_fields_outputs[i].squeezed_mask_array.reshape((1, -1, 8))
132+
x = all_scalar_fields_outputs[i].squeezed_mask_array[slice_corners].reshape((1, -1, 8))
131133
mask_matrix[i] = BackendTensor.t.sum(x, -1, bool)[0]
132134
onlap_chain_counter = 0
133135
case MeshExtractionMaskingOptions.INTERSECT, StackRelationType.ONLAP:
134-
x = all_scalar_fields_outputs[i].squeezed_mask_array.reshape((1, -1, 8))
136+
x = all_scalar_fields_outputs[i].squeezed_mask_array[slice_corners].reshape((1, -1, 8))
135137
mask_matrix[i] = BackendTensor.t.sum(x, -1, bool)[0]
136138
onlap_chain_counter += 1
137139
case _, StackRelationType.FAULT:

gempy_engine/API/interp_single/_octree_generation.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import numpy as np
1818

1919

20-
def interpolate_on_octree(interpolation_input: InterpolationInput, options: InterpolationOptions,
20+
def interpolate_on_octree_(interpolation_input: InterpolationInput, options: InterpolationOptions,
2121
data_shape: InputDataDescriptor) -> OctreeLevel:
2222
if BackendTensor.engine_backend is not AvailableBackends.PYTORCH and NOT_MAKE_INPUT_DEEP_COPY is False:
2323
temp_interpolation_input = copy.deepcopy(interpolation_input)
@@ -30,8 +30,9 @@ def interpolate_on_octree(interpolation_input: InterpolationInput, options: Inte
3030
# * Interpolate - corners
3131
grid_0_centers: EngineGrid = temp_interpolation_input.grid # ? This could be moved to the next section
3232
if options.compute_corners:
33-
grid_0_corners: Optional[EngineGrid] = EngineGrid.from_xyz_coords(
34-
xyz_coords=_generate_corners(regular_grid=grid_0_centers.octree_grid)
33+
xyz_corners = _generate_corners(regular_grid=grid_0_centers.octree_grid)
34+
grid_0_corners: EngineGrid = EngineGrid.from_xyz_coords(
35+
xyz_coords=xyz_corners
3536
)
3637

3738
# ! Here we need to swap the grid temporarily but it is
@@ -58,6 +59,51 @@ def interpolate_on_octree(interpolation_input: InterpolationInput, options: Inte
5859
return next_octree_level
5960

6061

62+
def interpolate_on_octree(interpolation_input: InterpolationInput, options: InterpolationOptions,
63+
data_shape: InputDataDescriptor) -> OctreeLevel:
64+
if BackendTensor.engine_backend is not AvailableBackends.PYTORCH and NOT_MAKE_INPUT_DEEP_COPY is False:
65+
temp_interpolation_input = copy.deepcopy(interpolation_input)
66+
else:
67+
temp_interpolation_input = interpolation_input
68+
69+
# * Interpolate - corners
70+
if options.compute_corners:
71+
grid_0_centers: EngineGrid = temp_interpolation_input.grid # ? This could be moved to the next section
72+
xyz_corners = _generate_corners(regular_grid=grid_0_centers.octree_grid)
73+
74+
corner_grid = GenericGrid(values=xyz_corners)
75+
grid_0_centers.corners_grid = corner_grid
76+
output_0_centers: List[InterpOutput] = interpolate_all_fields(temp_interpolation_input, options, data_shape) # interpolate - centers
77+
78+
# * DEP
79+
grid_0_corners = None
80+
output_0_corners = []
81+
82+
# * Create next octree level
83+
next_octree_level = OctreeLevel(
84+
grid_centers=grid_0_centers,
85+
grid_corners=grid_0_corners,
86+
outputs_centers=output_0_centers,
87+
outputs_corners=output_0_corners
88+
)
89+
else:
90+
grid_0_centers: EngineGrid = temp_interpolation_input.grid # ? This could be moved to the next section
91+
output_0_centers: List[InterpOutput] = interpolate_all_fields(temp_interpolation_input, options, data_shape) # interpolate - centers
92+
93+
# * DEP
94+
output_0_corners = []
95+
grid_0_corners = None
96+
97+
# * Create next octree level
98+
next_octree_level = OctreeLevel(
99+
grid_centers=grid_0_centers,
100+
grid_corners=grid_0_corners,
101+
outputs_centers=output_0_centers,
102+
outputs_corners=output_0_corners
103+
)
104+
105+
return next_octree_level
106+
61107
def _generate_corners_DEP(regular_grid: RegularGrid, level=1) -> np.ndarray:
62108
if regular_grid is None: raise ValueError("Regular grid is None")
63109

gempy_engine/core/data/engine_grid.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,23 @@ class EngineGrid:
2222
topography: Optional[GenericGrid] = None
2323
sections: Optional[GenericGrid] = None
2424
geophysics_grid: Optional[CenteredGrid] = None # TODO: Not implemented this probably will need something different that the generic grid?
25+
corners_grid: Optional[GenericGrid] = None # TODO: Not implemented this probably will need something different that the generic grid?
2526

2627
debug_vals = None
2728

2829
# ? Should we add the number of octrees here instead of the general options
2930

3031
def __init__(self, octree_grid: Optional[RegularGrid] = None, dense_grid: Optional[RegularGrid] = None,
3132
custom_grid: Optional[GenericGrid] = None, topography: Optional[GenericGrid] = None,
32-
sections: Optional[GenericGrid] = None, geophysics_grid: Optional[CenteredGrid] = None):
33+
sections: Optional[GenericGrid] = None, geophysics_grid: Optional[CenteredGrid] = None,
34+
corners_grid: Optional[GenericGrid] = None):
3335
self.octree_grid = octree_grid
3436
self.dense_grid = dense_grid
3537
self.custom_grid = custom_grid
3638
self.topography = topography
3739
self.sections = sections
3840
self.geophysics_grid = geophysics_grid
41+
self.corners_grid = corners_grid
3942

4043
@property
4144
def regular_grid(self):
@@ -78,6 +81,8 @@ def values(self) -> np.ndarray:
7881
values.append(self.sections.values)
7982
if self.geophysics_grid is not None:
8083
values.append(self.geophysics_grid.values)
84+
if self.corners_grid is not None:
85+
values.append(self.corners_grid.values)
8186

8287
values_array = BackendTensor.t.concatenate(values, dtype=BackendTensor.dtype)
8388
values_array = BackendTensor.t.array(values_array, dtype=BackendTensor.dtype)
@@ -131,6 +136,14 @@ def geophysics_grid_slice(self) -> slice:
131136
start + len(self.geophysics_grid) if self.geophysics_grid is not None else start
132137
)
133138

139+
@property
140+
def corners_grid_slice(self) -> slice:
141+
start = self.geophysics_grid_slice.stop
142+
return slice(
143+
start,
144+
start + len(self.corners_grid) if self.corners_grid is not None else start
145+
)
146+
134147
@property
135148
def len_all_grids(self) -> int:
136149
return self.values.shape[0]

gempy_engine/core/data/interp_output.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@ def custom_grid_values(self):
8484
def geophysics_grid_values(self):
8585
return self.block[self.grid.geophysics_grid_slice]
8686

87+
@property
88+
def cornersGrid_values(self):
89+
return self.block[self.grid.corners_grid_slice]
90+
91+
@property
92+
def ids_cornersGrid(self):
93+
return BackendTensor.t.rint(self.block[self.grid.corners_grid_slice])
94+
8795
@property
8896
def ids_geophysics_grid(self):
8997
return BackendTensor.t.rint(self.block[self.grid.geophysics_grid_slice])
@@ -132,6 +140,21 @@ def litho_faults_ids(self):
132140
# Generate the unique IDs
133141
unique_ids = litho_ids + faults_ids * multiplier
134142
return unique_ids
143+
144+
@property
145+
def litho_faults_ids_corners_grid(self):
146+
if self.combined_scalar_field is None: # * This in principle is only used for testing
147+
return self.ids_cornersGrid
148+
149+
litho_ids = BackendTensor.t.rint(self.block[self.grid.corners_grid_slice])
150+
faults_ids = BackendTensor.t.rint(self.faults_block[self.grid.corners_grid_slice])
151+
152+
# Get the number of unique lithology IDs
153+
multiplier = len(BackendTensor.t.unique(litho_ids))
154+
155+
# Generate the unique IDs
156+
unique_ids = litho_ids + faults_ids * multiplier
157+
return unique_ids
135158

136159
def get_block_from_value_type(self, value_type: ValueType, slice_: slice):
137160
match value_type:

gempy_engine/core/data/octree_level.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@ def output_corners(self): # * Alias
4848
@property
4949
def last_output_corners(self):
5050
return self.outputs_corners[-1]
51-
51+
52+
@property
53+
def litho_faults_ids_corners_grid(self):
54+
return self.outputs_centers[-1].litho_faults_ids_corners_grid
55+
5256
@property
5357
def number_of_outputs(self):
5458
return len(self.outputs_centers)

gempy_engine/core/data/regular_grid.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ class RegularGrid:
2020
original_values: np.ndarray = field(default=None, repr=False, init=False) #: When the regular grid is representing a octree level, only active cells are stored in values. This is the original values of the regular grid.
2121

2222
def __len__(self):
23-
return self.regular_grid_shape.prod()
23+
# return self.regular_grid_shape.prod()
24+
return self.values.shape[0]
2425

2526
def __post_init__(self):
2627
self.regular_grid_shape = _check_and_convert_list_to_array(self.regular_grid_shape)

gempy_engine/modules/octrees_topology/_octree_internals.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
def compute_next_octree_locations(prev_octree: OctreeLevel, evaluation_options: EvaluationOptions,
1616
current_octree_level: int) -> EngineGrid:
17-
ids = prev_octree.last_output_corners.litho_faults_ids
17+
ids = prev_octree.litho_faults_ids_corners_grid
1818
uv_8 = ids.reshape((-1, 8))
1919

2020
# Old octree

tests/test_common/test_integrations/test_multi_fields.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,7 @@ def test_plot_corners(unconformity_complex, n_oct_levels=2):
147147
interpolation_input, options, structure = unconformity_complex
148148
options.number_octree_levels = n_oct_levels
149149
solutions: Solutions = compute_model(interpolation_input, options, structure)
150-
output_corners: InterpOutput = solutions.octrees_output[-1].outputs_corners[-1]
151-
152-
vertices = output_corners.grid.values
150+
vertices = solutions.octrees_output[-1].grid_centers.corners_grid.values
153151
if plot_pyvista or False:
154152
helper_functions_pyvista.plot_pyvista(solutions.octrees_output, v_just_points=vertices)
155153

0 commit comments

Comments
 (0)