Skip to content

Commit 95a580a

Browse files
authored
[ENH] Add model serialization validation and refactor grid modules (#1032)
# Description Refactored grid modules to improve serialization and added validation capabilities to the model computation process. Key changes include: - Split grid_types.py into separate files for each grid type (RegularGrid, CustomGrid, Sections) - Converted CustomGrid and Sections to dataclasses with proper field definitions - Added serialization validation during model computation via environment variable - Fixed parameter names in grid API functions to match updated class structures - Added dotenv support to compute_API for configuration - Updated geophysics_input to be properly serializable - Added verification tests for model serialization in multiple test cases - Added approved verification files for serialization tests Relates to #serialization-improvements # Checklist - [x] My code uses type hinting for function and method arguments and return values. - [x] I have created tests which cover my code. - [x] The test code either 1. demonstrates at least one valuable use case (e.g. integration tests) or 2. verifies that outputs are as expected for given inputs (e.g. unit tests). - [x] New tests pass locally with my changes.
2 parents 2be0238 + c5041f4 commit 95a580a

File tree

22 files changed

+1133
-217
lines changed

22 files changed

+1133
-217
lines changed

docs/developers_notes/dev_log/2025_05.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,8 @@
2626
- Solutions (Not sure if I want to save them)
2727
- RawArraysSolutions
2828
- Solutions
29+
- Testing
30+
- [ ] 10 Test to go in modules
31+
- [ ] test api
2932

3033
## What do I have in the engine server logic?

gempy/API/compute_API.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from typing import Optional
1+
import dotenv
2+
import os
3+
4+
from typing import Optional
25

36
import numpy as np
47

@@ -14,8 +17,11 @@
1417
from ..modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
1518
from ..optional_dependencies import require_gempy_legacy
1619

20+
dotenv.load_dotenv()
21+
1722

18-
def compute_model(gempy_model: GeoModel, engine_config: Optional[GemPyEngineConfig] = None) -> Solutions:
23+
def compute_model(gempy_model: GeoModel, engine_config: Optional[GemPyEngineConfig] = None,
24+
**kwargs) -> Solutions:
1925
"""
2026
Compute the geological model given the provided GemPy model.
2127
@@ -56,6 +62,12 @@ def compute_model(gempy_model: GeoModel, engine_config: Optional[GemPyEngineConf
5662
case _:
5763
raise ValueError(f'Backend {engine_config} not supported')
5864

65+
if os.getenv("VALIDATE_SERIALIZATION", False) and kwargs.get("validate_serialization", True):
66+
from ..modules.serialization.save_load import save_model
67+
import tempfile
68+
with tempfile.NamedTemporaryFile(mode='w+', delete=True) as tmp:
69+
save_model(model=gempy_model, path=tmp.name, validate_serialization=True)
70+
5971
return gempy_model.solutions
6072

6173

@@ -79,7 +91,7 @@ def compute_model_at(gempy_model: GeoModel, at: np.ndarray,
7991
xyz_coord=at
8092
)
8193

82-
sol = compute_model(gempy_model, engine_config)
94+
sol = compute_model(gempy_model, engine_config, validate_serialization=False)
8395
return sol.raw_arrays.custom
8496

8597

gempy/API/grid_API.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111

1212
def set_section_grid(grid: Grid, section_dict: dict):
1313
if grid.sections is None:
14-
grid.sections = Sections(regular_grid=grid.regular_grid, section_dict=section_dict)
14+
grid.sections = Sections(
15+
z_ext=grid.regular_grid.extent[4:],
16+
section_dict=section_dict
17+
)
1518
else:
1619
grid.sections.set_sections(section_dict,
1720
regular_grid=grid.regular_grid)
@@ -54,9 +57,9 @@ def set_topography_from_random(grid: Grid, fractal_dimension: float = 2.0, d_z:
5457
dz=d_z,
5558
fractal_dimension=fractal_dimension
5659
)
57-
60+
5861
grid.topography = Topography(
59-
regular_grid=grid.regular_grid,
62+
_regular_grid=grid.regular_grid,
6063
values_2d=random_topography
6164
)
6265

@@ -70,7 +73,7 @@ def set_topography_from_subsurface_structured_grid(grid: Grid, struct: "subsurfa
7073
return grid.topography
7174

7275

73-
def set_topography_from_arrays(grid: Grid, xyz_vertices: np.ndarray):
76+
def set_topography_from_arrays(grid: Grid, xyz_vertices: np.ndarray):
7477
grid.topography = Topography.from_unstructured_mesh(grid.regular_grid, xyz_vertices)
7578
set_active_grid(grid, [Grid.GridTypes.TOPOGRAPHY])
7679
return grid.topography
@@ -86,9 +89,9 @@ def set_topography_from_file(grid: Grid, filepath: str, crop_to_extent: Union[Se
8689

8790

8891
def set_custom_grid(grid: Grid, xyz_coord: np.ndarray):
89-
custom_grid = CustomGrid(xyx_coords=xyz_coord)
92+
custom_grid = CustomGrid(values=xyz_coord)
9093
grid.custom_grid = custom_grid
91-
94+
9295
set_active_grid(grid, [Grid.GridTypes.CUSTOM])
9396
return grid.custom_grid
9497

gempy/core/data/encoders/converters.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Annotated
2+
13
from contextlib import contextmanager
24

35
from contextvars import ContextVar
@@ -17,6 +19,9 @@ def validate_numpy_array(v):
1719
return np.array(v) if v is not None else None
1820

1921

22+
short_array_type = Annotated[np.ndarray, (BeforeValidator(lambda v: np.array(v) if v is not None else None))]
23+
24+
2025
def instantiate_if_necessary(data: dict, key: str, type: type) -> None:
2126
"""
2227
Creates instances of the specified type for a dictionary key if the key exists and its

gempy/core/data/geo_model.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from gempy_engine.core.data.interpolation_input import InterpolationInput
1616
from gempy_engine.core.data.raw_arrays_solution import RawArraysSolution
1717
from gempy_engine.core.data.transforms import Transform, GlobalAnisotropy
18+
from gempy_engine.modules.geophysics.gravity_gradient import calculate_gravity_gradient
1819
from .encoders.converters import instantiate_if_necessary
1920
from .encoders.json_geomodel_encoder import encode_numpy_array
2021
from .grid import Grid
@@ -23,6 +24,7 @@
2324
from .surface_points import SurfacePointsTable
2425
from ...modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
2526

27+
2628
"""
2729
TODO:
2830
- [ ] StructuralFrame will all input points chunked on Elements. Here I will need a property to put all
@@ -62,7 +64,7 @@ class GeoModel(BaseModel):
6264
# region GemPy engine data types
6365
_interpolation_options: InterpolationOptions #: The interpolation options provided by the user.
6466

65-
geophysics_input: GeophysicsInput = Field(default=None, exclude=True) #: The geophysics input of the geological model.
67+
geophysics_input: GeophysicsInput | None = Field(default=None, exclude=False) #: The geophysics input of the geological model.
6668
input_transform: Transform = Field(default=None, exclude=False) #: The transformation used in the geological model for input points.
6769

6870
interpolation_grid: EngineGrid = Field(default=None, exclude=True) #: ptional grid used for interpolation. Can be seen as a cache field.
@@ -295,24 +297,30 @@ def add_surface_points(self, X: Sequence[float], Y: Sequence[float], Z: Sequence
295297
arbitrary_types_allowed=True,
296298
use_enum_values=False,
297299
json_encoders={
298-
np.ndarray: encode_numpy_array
300+
np.ndarray: encode_numpy_array,
299301
}
300302
)
301303

302304
@model_validator(mode='wrap')
303305
@classmethod
304306
def deserialize_properties(cls, data: Union["GeoModel", dict], constructor: ModelWrapValidatorHandler["GeoModel"]) -> "GeoModel":
305307
match data:
306-
case GeoModel():
308+
case GeoModel():
307309
return data
308-
case dict():
310+
case dict(): #
309311
instance: GeoModel = constructor(data)
310312
instantiate_if_necessary(
311313
data=data,
312314
key="_interpolation_options",
313315
type=InterpolationOptions
314316
)
315317
instance._interpolation_options = data.get("_interpolation_options")
318+
319+
# * Reset geophysics if necessary
320+
centered_grid = instance.grid.centered_grid
321+
if centered_grid is not None and instance.geophysics_input is not None:
322+
instance.geophysics_input.tz = calculate_gravity_gradient(centered_grid)
323+
316324
return instance
317325
case _:
318326
raise ValidationError
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
1-
from .grid_types import Sections, RegularGrid, CustomGrid
1+
from .regular_grid import RegularGrid
2+
from .custom_grid import CustomGrid
3+
from .sections_grid import Sections
24
from .topography import Topography
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import dataclasses
2+
import numpy as np
3+
from pydantic import Field
4+
5+
6+
@dataclasses.dataclass
7+
class CustomGrid:
8+
"""Object that contains arbitrary XYZ coordinates.
9+
10+
Args:
11+
xyx_coords (numpy.ndarray like): XYZ (in columns) of the desired coordinates
12+
13+
Attributes:
14+
values (np.ndarray): XYZ coordinates
15+
"""
16+
17+
values: np.ndarray = Field(
18+
exclude=True,
19+
default_factory=lambda: np.zeros((0, 3)),
20+
repr=False
21+
)
22+
23+
24+
def __post_init__(self):
25+
custom_grid = np.atleast_2d(self.values)
26+
assert type(custom_grid) is np.ndarray and custom_grid.shape[1] == 3, \
27+
'The shape of new grid must be (n,3) where n is the number of' \
28+
' points of the grid'
29+
30+
31+
@property
32+
def length(self):
33+
return self.values.shape[0]

gempy/core/data/grid_modules/grid_types.py renamed to gempy/core/data/grid_modules/regular_grid.py

Lines changed: 1 addition & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,9 @@
55

66
import numpy as np
77

8-
from ..core_utils import calculate_line_coordinates_2points
98
from ..encoders.converters import numpy_array_short_validator
109
from .... import optional_dependencies
11-
from ....optional_dependencies import require_pandas
12-
from gempy_engine.core.data.transforms import Transform, TransformOpsOrder
10+
from gempy_engine.core.data.transforms import Transform
1311

1412

1513
@dataclasses.dataclass
@@ -255,145 +253,3 @@ def plot_rotation(regular_grid, pivot, point_x_axis, point_y_axis):
255253
plt.show()
256254

257255

258-
class Sections:
259-
"""
260-
Object that creates a grid of cross sections between two points.
261-
262-
Args:
263-
regular_grid: Model.grid.regular_grid
264-
section_dict: {'section name': ([p1_x, p1_y], [p2_x, p2_y], [xyres, zres])}
265-
"""
266-
267-
def __init__(self, regular_grid=None, z_ext=None, section_dict=None):
268-
pd = require_pandas()
269-
if regular_grid is not None:
270-
self.z_ext = regular_grid.extent[4:]
271-
else:
272-
self.z_ext = z_ext
273-
274-
self.section_dict = section_dict
275-
self.names = []
276-
self.points = []
277-
self.resolution = []
278-
self.length = [0]
279-
self.dist = []
280-
self.df = pd.DataFrame()
281-
self.df['dist'] = self.dist
282-
self.values = np.empty((0, 3))
283-
self.extent = None
284-
285-
if section_dict is not None:
286-
self.set_sections(section_dict)
287-
288-
def _repr_html_(self):
289-
return self.df.to_html()
290-
291-
def __repr__(self):
292-
return self.df.to_string()
293-
294-
def show(self):
295-
pass
296-
297-
def set_sections(self, section_dict, regular_grid=None, z_ext=None):
298-
pd = require_pandas()
299-
self.section_dict = section_dict
300-
if regular_grid is not None:
301-
self.z_ext = regular_grid.extent[4:]
302-
303-
self.names = np.array(list(self.section_dict.keys()))
304-
305-
self.get_section_params()
306-
self.calculate_all_distances()
307-
self.df = pd.DataFrame.from_dict(self.section_dict, orient='index', columns=['start', 'stop', 'resolution'])
308-
self.df['dist'] = self.dist
309-
310-
self.compute_section_coordinates()
311-
312-
def get_section_params(self):
313-
self.points = []
314-
self.resolution = []
315-
self.length = [0]
316-
317-
for i, section in enumerate(self.names):
318-
points = [self.section_dict[section][0], self.section_dict[section][1]]
319-
assert points[0] != points[
320-
1], 'The start and end points of the section must not be identical.'
321-
322-
self.points.append(points)
323-
self.resolution.append(self.section_dict[section][2])
324-
self.length = np.append(self.length, self.section_dict[section][2][0] *
325-
self.section_dict[section][2][1])
326-
self.length = np.array(self.length).cumsum()
327-
328-
def calculate_all_distances(self):
329-
self.coordinates = np.array(self.points).ravel().reshape(-1,
330-
4) # axis are x1,y1,x2,y2
331-
self.dist = np.sqrt(np.diff(self.coordinates[:, [0, 2]]) ** 2 + np.diff(
332-
self.coordinates[:, [1, 3]]) ** 2)
333-
334-
def compute_section_coordinates(self):
335-
for i in range(len(self.names)):
336-
xy = calculate_line_coordinates_2points(self.coordinates[i, :2],
337-
self.coordinates[i, 2:],
338-
self.resolution[i][0])
339-
zaxis = np.linspace(self.z_ext[0], self.z_ext[1], self.resolution[i][1],
340-
dtype="float64")
341-
X, Z = np.meshgrid(xy[:, 0], zaxis, indexing='ij')
342-
Y, _ = np.meshgrid(xy[:, 1], zaxis, indexing='ij')
343-
xyz = np.vstack((X.flatten(), Y.flatten(), Z.flatten())).T
344-
if i == 0:
345-
self.values = xyz
346-
else:
347-
self.values = np.vstack((self.values, xyz))
348-
349-
def generate_axis_coord(self):
350-
for i, name in enumerate(self.names):
351-
xy = calculate_line_coordinates_2points(
352-
self.coordinates[i, :2],
353-
self.coordinates[i, 2:],
354-
self.resolution[i][0]
355-
)
356-
yield name, xy
357-
358-
def get_section_args(self, section_name: str):
359-
where = np.where(self.names == section_name)[0][0]
360-
return self.length[where], self.length[where + 1]
361-
362-
def get_section_grid(self, section_name: str):
363-
l0, l1 = self.get_section_args(section_name)
364-
return self.values[l0:l1]
365-
366-
367-
class CustomGrid:
368-
"""Object that contains arbitrary XYZ coordinates.
369-
370-
Args:
371-
xyx_coords (numpy.ndarray like): XYZ (in columns) of the desired coordinates
372-
373-
Attributes:
374-
values (np.ndarray): XYZ coordinates
375-
"""
376-
377-
def __init__(self, xyx_coords: np.ndarray):
378-
self.values = np.zeros((0, 3))
379-
self.set_custom_grid(xyx_coords)
380-
381-
def set_custom_grid(self, custom_grid: np.ndarray):
382-
"""
383-
Give the coordinates of an external generated grid
384-
385-
Args:
386-
custom_grid (numpy.ndarray like): XYZ (in columns) of the desired coordinates
387-
388-
Returns:
389-
numpy.ndarray: Unraveled 3D numpy array where every row correspond to the xyz coordinates of a regular
390-
grid
391-
"""
392-
custom_grid = np.atleast_2d(custom_grid)
393-
assert type(custom_grid) is np.ndarray and custom_grid.shape[1] == 3, \
394-
'The shape of new grid must be (n,3) where n is the number of' \
395-
' points of the grid'
396-
397-
self.values = custom_grid
398-
self.length = self.values.shape[0]
399-
return self.values

0 commit comments

Comments
 (0)