Skip to content

Commit 08e510d

Browse files
authored
[ENH] Add grid serialization and improve binary data handling (#1034)
# Description

Enhanced model serialization to include grid data in binary format. This PR improves the serialization process by:

1. Adding support for serializing and deserializing grid data (custom grid and topography)
2. Improving binary encoding/decoding with proper length tracking for different data sections
3. Adding docstrings to binary encoder functions for better code documentation
4. Fixing type checking in `calculate_line_coordinates_2points` to support tuples in addition to lists
5. Enabling validation of serialization by default in compute functions
6. Updating tests to work with the new serialization format

Relates to #serialization-enhancement

# Checklist

- [x] My code uses type hinting for function and method arguments and return values.
- [x] I have created tests which cover my code.
- [x] The test code either 1. demonstrates at least one valuable use case (e.g. integration tests) or 2. verifies that outputs are as expected for given inputs (e.g. unit tests).
- [x] New tests pass locally with my changes.
2 parents d34e6e2 + 820f798 commit 08e510d

File tree

21 files changed

+486
-67
lines changed

21 files changed

+486
-67
lines changed

gempy/API/compute_API.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def compute_model_at(gempy_model: GeoModel, at: np.ndarray,
9191
xyz_coord=at
9292
)
9393

94-
sol = compute_model(gempy_model, engine_config, validate_serialization=False)
94+
sol = compute_model(gempy_model, engine_config, validate_serialization=True)
9595
return sol.raw_arrays.custom
9696

9797

gempy/core/data/core_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55

66
def calculate_line_coordinates_2points(p1, p2, res):
7-
if isinstance(p1, list):
7+
if isinstance(p1, list) or isinstance(p1, tuple):
88
p1 = np.array(p1)
9-
if isinstance(p2, list):
9+
if isinstance(p2, list) or isinstance(p2, tuple):
1010
p2 = np.array(p2)
1111
v = p2 - p1 # vector pointing from p1 to p2
1212
u = v / np.linalg.norm(v) # normalize it

gempy/core/data/encoders/binary_encoder.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,68 @@
44
from ..orientations import OrientationsTable
55

66

7-
def deserialize_input_data_tables(binary_array: bytes, name_id_map: dict, sp_binary_length_: int) -> tuple[OrientationsTable, SurfacePointsTable]:
7+
def deserialize_input_data_tables(binary_array: bytes, name_id_map: dict,
                                  sp_binary_length_: int, ori_binary_length_: int) -> tuple[OrientationsTable, SurfacePointsTable]:
    """
    Deserialize binary data into an OrientationsTable and a SurfacePointsTable.

    The buffer is laid out as the surface-points segment followed by the
    orientations segment. The two lengths are used to slice the buffer, and each
    slice is reinterpreted with the structured dtype of its table
    (``SurfacePointsTable.dt`` / ``OrientationsTable.dt``).

    Args:
        binary_array (bytes): Serialized data for both tables, surface points first.
        name_id_map (dict): Mapping of names to IDs, passed unchanged to both
            table constructors.
        sp_binary_length_ (int): Length in bytes of the surface-points segment.
        ori_binary_length_ (int): Length in bytes of the orientations segment.

    Returns:
        tuple[OrientationsTable, SurfacePointsTable]: The orientations table
        first, then the surface points table.

    Raises:
        ValueError: If a length is negative, if the two segments together exceed
            the size of ``binary_array``, or (raised by ``np.frombuffer``) if a
            segment size is not a multiple of its dtype's item size.
    """
    if sp_binary_length_ < 0 or ori_binary_length_ < 0:
        raise ValueError(
            f"Segment lengths must be non-negative, got sp_binary_length_={sp_binary_length_}, "
            f"ori_binary_length_={ori_binary_length_}."
        )

    # Python slicing never fails on out-of-range bounds; without this check a short
    # (corrupted) buffer would silently produce truncated tables.
    ori_end = sp_binary_length_ + ori_binary_length_
    if ori_end > len(binary_array):
        raise ValueError(
            f"Binary data holds {len(binary_array)} bytes but {ori_end} bytes are required "
            f"({sp_binary_length_} for surface points + {ori_binary_length_} for orientations)."
        )

    sp_binary = binary_array[:sp_binary_length_]
    ori_binary = binary_array[sp_binary_length_:ori_end]

    # Reconstruct arrays
    sp_data: np.ndarray = np.frombuffer(sp_binary, dtype=SurfacePointsTable.dt)
    ori_data: np.ndarray = np.frombuffer(ori_binary, dtype=OrientationsTable.dt)

    surface_points_table = SurfacePointsTable(data=sp_data, name_id_map=name_id_map)
    orientations_table = OrientationsTable(data=ori_data, name_id_map=name_id_map)
    return orientations_table, surface_points_table
40+
41+
42+
def deserialize_grid(binary_array: bytes, custom_grid_length: int, topography_length: int) -> tuple[np.ndarray, np.ndarray]:
    """
    Deserialize binary grid data into two flat float64 numpy arrays.

    The buffer is laid out as the custom-grid segment followed by the topography
    segment. Both segments are decoded as ``float64``; reshaping is left to the
    caller.

    Args:
        binary_array: The binary data holding the custom grid followed by the topography.
        custom_grid_length: The length of the custom grid data segment in bytes.
        topography_length: The length of the topography data segment in bytes.

    Returns:
        A tuple where the first element is the 1-D array decoded from the custom
        grid segment and the second element is the 1-D array decoded from the
        topography segment. A zero-length segment yields an empty array.

    Raises:
        ValueError: If a length is negative, if the two segments together exceed
            the size of ``binary_array``, or (raised by ``np.frombuffer``) if a
            segment length is not a multiple of 8 bytes.
    """
    if custom_grid_length < 0 or topography_length < 0:
        raise ValueError(
            f"Grid segment lengths must be non-negative, got custom_grid_length={custom_grid_length}, "
            f"topography_length={topography_length}."
        )

    # Python slicing never fails on out-of-range bounds; without this check a short
    # (corrupted) buffer would silently produce truncated arrays.
    topography_end = custom_grid_length + topography_length
    if topography_end > len(binary_array):
        raise ValueError(
            f"Grid binary holds {len(binary_array)} bytes but {topography_end} bytes are required "
            f"({custom_grid_length} for the custom grid + {topography_length} for the topography)."
        )

    custom_grid_binary = binary_array[:custom_grid_length]
    topography_binary = binary_array[custom_grid_length:topography_end]

    # Both segments are decoded as float64; the dtype is spelled out for each
    # instead of relying on frombuffer's default.
    custom_grid = np.frombuffer(custom_grid_binary, dtype=np.float64)
    topography = np.frombuffer(topography_binary, dtype=np.float64)
    return custom_grid, topography

gempy/core/data/encoders/converters.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@ def instantiate_if_necessary(data: dict, key: str, type: type) -> None:
5252
loading_model_context = ContextVar('loading_model_context', default={})
5353

5454
@contextmanager
55-
def loading_model_from_binary(binary_body: bytes):
55+
def loading_model_from_binary(input_binary: bytes, grid_binary: bytes):
5656
token = loading_model_context.set({
57-
'binary_body': binary_body,
57+
'input_binary': input_binary,
58+
'grid_binary': grid_binary
5859
})
5960
try:
6061
yield

gempy/core/data/grid.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from gempy_engine.core.data.centered_grid import CenteredGrid
99
from gempy_engine.core.data.options import EvaluationOptions
1010
from gempy_engine.core.data.transforms import Transform
11+
from .encoders.binary_encoder import deserialize_grid
12+
from .encoders.converters import loading_model_context
1113
from .grid_modules import RegularGrid, CustomGrid, Sections
1214
from .grid_modules.topography import Topography
1315

@@ -59,13 +61,49 @@ def deserialize_properties(cls, data: Union["Grid", dict], constructor: ModelWra
5961
case dict():
6062
grid: Grid = constructor(data)
6163
grid._active_grids = Grid.GridTypes(data["active_grids"])
64+
# TODO: Digest binary data
65+
66+
metadata = data.get('binary_meta_data', {})
67+
context = loading_model_context.get()
68+
69+
if 'grid_binary' not in context:
70+
return grid
71+
72+
custom_grid_vals, topography_vals = deserialize_grid(
73+
binary_array=context['grid_binary'],
74+
custom_grid_length=metadata["custom_grid_binary_length"],
75+
topography_length=metadata["topography_binary_length"]
76+
)
77+
78+
if grid.custom_grid is not None:
79+
grid.custom_grid.values = custom_grid_vals.reshape(-1, 3)
80+
81+
if grid.topography is not None:
82+
grid.topography.set_values2d(values=topography_vals)
83+
6284
grid._update_values()
6385
return grid
6486
case _:
6587
raise ValidationError
6688
except ValidationError:
6789
raise
6890

91+
@property
def grid_binary(self):
    """Raw bytes of the grid: custom-grid values followed by topography values.

    Each present (truthy) module contributes its ``values`` cast to float64 and
    dumped with ``tobytes()``; a missing module contributes nothing.
    """
    payload = b''
    if self._custom_grid:
        payload += self._custom_grid.values.astype("float64").tobytes()
    if self._topography:
        payload += self._topography.values.astype("float64").tobytes()
    return payload
96+
97+
98+
_grid_binary_size: int = 0
99+
@computed_field
def binary_meta_data(self) -> dict:
    """Metadata describing the layout of the grid's binary section.

    Returns:
        dict with
        - ``custom_grid_binary_length``: byte length of the custom grid values
          when stored as float64 (0 if there is no custom grid),
        - ``topography_binary_length``: byte length of the topography values
          when stored as float64 (0 if there is no topography),
        - ``grid_binary_size``: the current value of ``_grid_binary_size``.
    """
    float64_itemsize = 8  # bytes per element once cast to float64

    def _float64_nbytes(grid_module) -> int:
        # Same number as len(values.astype("float64").tobytes()), but computed from
        # the element count so no throw-away copies of the array are allocated.
        if not grid_module:
            return 0
        return grid_module.values.size * float64_itemsize

    return {
        'custom_grid_binary_length': _float64_nbytes(self._custom_grid),
        'topography_binary_length': _float64_nbytes(self._topography),
        'grid_binary_size': self._grid_binary_size
    }
106+
69107
@computed_field(alias="active_grids")
70108
@property
71109
def active_grids(self) -> GridTypes:

gempy/core/data/grid_modules/topography.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class Topography:
2929

3030
# Fields managed internally
3131
values: short_array_type = field(init=False, default=np.zeros((0, 3)))
32-
resolution: Tuple[int, int] = field(init=False, default=(0, 0))
32+
resolution: Tuple[int, int] = Field(init=True, default=(0, 0))
3333
raster_shape: Tuple[int, ...] = field(init=False, default=())
3434
_mask_topo: Optional[np.ndarray] = field(init=False, default=None, repr=False)
3535
_x: Optional[np.ndarray] = field(init=False, default=None, repr=False)
@@ -95,7 +95,7 @@ def from_unstructured_mesh(cls, regular_grid, xyz_vertices):
9595
# Reshape the grid for compatibility with existing structure
9696
values_2d = np.stack((x_regular, y_regular, z_regular), axis=-1)
9797

98-
return cls(regular_grid=regular_grid, values_2d=values_2d)
98+
return cls(_regular_grid=regular_grid, values_2d=values_2d)
9999

100100

101101
@classmethod
@@ -107,7 +107,7 @@ def from_arrays(cls, regular_grid, x_coordinates, y_coordinates, height_values,)
107107
topography_vals = height_values.values[:, :, np.newaxis] # shape (73, 34, 1)
108108
# Stack along the last dimension
109109
result = np.concatenate([x_vals, y_vals, topography_vals], axis=2) # shape (73, 34, 3)
110-
return cls(regular_grid=regular_grid, values_2d=result)
110+
return cls(_regular_grid=regular_grid, values_2d=result)
111111

112112
@property
113113
def extent(self):
@@ -155,6 +155,34 @@ def set_values(self, values_2d: np.ndarray):
155155
self.values = values_2d.reshape((-1, 3), order='C')
156156
return self
157157

158+
def set_values2d(self, values: np.ndarray) -> "Topography":
    """
    Rebuild the topography from an array holding ``nx * ny * 3`` elements.

    ``values`` is reshaped in C-order to ``(nx, ny, 3)``, where ``(nx, ny)`` is
    ``self.resolution``, and handed to ``set_values``. The cached topography
    mask is reset afterwards.

    Args:
        values: Array whose total size equals ``nx * ny * 3``.

    Returns:
        This instance, to allow chaining.

    Raises:
        ValueError: If ``values.size`` does not match ``nx * ny * 3``.
    """
    nx, ny = self.resolution
    if values.size != nx * ny * 3:
        raise ValueError(
            f"Cannot reshape array of size {values.size} into shape {(nx, ny, 3)}."
        )

    grid_xyz = values.reshape(nx, ny, 3, order="C")
    self.set_values(values_2d=grid_xyz)

    # Any previously computed mask no longer matches the new values
    self._mask_topo = None
    return self
181+
182+
def set_values2d_(self, values: np.ndarray, resolution: tuple[int, int] = (60, 60)):
    """
    Store ``values`` on ``self.values_2d`` reshaped to ``(*resolution, 3)``.

    Args:
        values: Array holding ``resolution[0] * resolution[1] * 3`` elements.
        resolution: Target ``(nx, ny)`` shape. Defaults to ``(60, 60)``, the
            value that used to be hard-coded, so existing calls are unaffected.

    Raises:
        ValueError: Raised by ``reshape`` if the size of ``values`` does not
            match the requested shape.
    """
    self.values_2d = values.reshape(*resolution, 3)
185+
158186
@property
159187
def topography_mask(self):
160188
"""This method takes a topography grid of the same extent as the regular

gempy/core/data/structural_frame.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -476,13 +476,14 @@ def deserialize_binary(cls, data: Union["StructuralFrame", dict], constructor: M
476476
metadata = data.get('binary_meta_data', {})
477477
context = loading_model_context.get()
478478

479-
if 'binary_body' not in context:
479+
if 'input_binary' not in context:
480480
return instance
481481

482482
instance.orientations, instance.surface_points = deserialize_input_data_tables(
483-
binary_array=context['binary_body'],
483+
binary_array=context['input_binary'],
484484
name_id_map=instance.surface_points_copy.name_id_map,
485-
sp_binary_length_=metadata["sp_binary_length"]
485+
sp_binary_length_=metadata["sp_binary_length"],
486+
ori_binary_length_=metadata["ori_binary_length"]
486487
)
487488

488489
return instance
@@ -491,12 +492,13 @@ def deserialize_binary(cls, data: Union["StructuralFrame", dict], constructor: M
491492

492493
# Access the context variable to get injected data
493494

494-
495+
_input_binary_size: int = 0
495496
@computed_field
def binary_meta_data(self) -> dict:
    """Metadata describing the layout of the input-tables binary section.

    Returns:
        dict with
        - ``sp_binary_length``: byte length of the surface points data,
        - ``ori_binary_length``: byte length of the orientations data,
        - ``input_binary_size``: the current value of ``_input_binary_size``.
    """
    # nbytes equals len(data.tobytes()) without allocating a throw-away bytes copy.
    return {
        'sp_binary_length': self.surface_points_copy.data.nbytes,
        'ori_binary_length': self.orientations_copy.data.nbytes,
        'input_binary_size': self._input_binary_size
    }
501503

502504
# endregion

gempy/modules/serialization/save_load.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,7 @@ def save_model(model: GeoModel, path: str | None = None, validate_serialization:
4444
# If no extension, add the valid extension
4545
path = str(path_obj) + VALID_EXTENSION
4646

47-
model_json = model.model_dump_json(by_alias=True, indent=4)
48-
49-
# Compress the binary data
50-
zlib = require_zlib()
51-
compressed_binary = zlib.compress(model.structural_frame.input_tables_binary)
52-
53-
binary_file = _to_binary(model_json, compressed_binary)
47+
binary_file = model_to_binary(model)
5448

5549
if validate_serialization:
5650
model_deserialized = _deserialize_binary_file(binary_file)
@@ -67,6 +61,35 @@ def save_model(model: GeoModel, path: str | None = None, validate_serialization:
6761
return path # Return the actual path used (helpful if extension was added)
6862

6963

64+
def model_to_binary(model: GeoModel) -> bytes:
    """Serialize ``model`` into one binary blob (JSON header + compressed bodies).

    The input tables and the grid are compressed separately with zlib. Their
    compressed sizes are written onto the model (``_input_binary_size`` and
    ``_grid_binary_size``) *before* the JSON header is dumped, so that the header
    carries the sizes needed to split the body again when loading.

    Args:
        model: The model to serialize. Its ``structural_frame._input_binary_size``
            and ``grid._grid_binary_size`` attributes are updated as a side effect.

    Returns:
        The bytes produced by ``_to_binary`` from the JSON header, the compressed
        input tables and the compressed grid.
    """
    zlib = require_zlib()

    # Read each payload once: every access re-serializes the underlying arrays.
    compressed_binary_input = zlib.compress(model.structural_frame.input_tables_binary)
    compressed_binary_grid = zlib.compress(model.grid.grid_binary, level=6)

    # * Add here the serialization meta parameters like: len_bytes
    # These must be set before model_dump_json so they end up in the header.
    model.structural_frame._input_binary_size = len(compressed_binary_input)
    model.grid._grid_binary_size = len(compressed_binary_grid)

    model_json = model.model_dump_json(by_alias=True, indent=4)
    return _to_binary(
        header_json=model_json,
        body_input=compressed_binary_input,
        body_grid=compressed_binary_grid
    )
91+
92+
7093
def load_model(path: str) -> GeoModel:
7194
"""
7295
Load a GeoModel from a file with extension validation.
@@ -110,39 +133,50 @@ def load_model(path: str) -> GeoModel:
110133

111134

112135
def _deserialize_binary_file(binary_file):
    """Rebuild a GeoModel from the bytes of a serialized model file.

    Layout of ``binary_file``: a 4-byte little-endian header length, the UTF-8
    JSON header, the zlib-compressed input tables, then the zlib-compressed
    grid. The sizes of the two compressed sections are read from the header's
    ``binary_meta_data`` entries.

    Raises:
        ValueError: If the section sizes declared in the header do not add up to
            the length of ``binary_file``.
    """
    import json

    prefix_size = 4
    header_length = int.from_bytes(binary_file[:prefix_size], byteorder='little')
    input_start = prefix_size + header_length
    header_json = binary_file[prefix_size:input_start].decode('utf-8')

    header = json.loads(header_json)
    input_size = header["structural_frame"]["binary_meta_data"]["input_binary_size"]
    grid_size = header["grid"]["binary_meta_data"]["grid_binary_size"]

    grid_start = input_start + input_size
    expected_total = grid_start + grid_size
    if expected_total != len(binary_file):
        raise ValueError("Binary file is corrupted")

    compressed_input = binary_file[input_start:grid_start]
    compressed_grid = binary_file[grid_start:expected_total]

    zlib = require_zlib()
    inflated_input = zlib.decompress(compressed_input)
    inflated_grid = zlib.decompress(compressed_grid)

    # The decompressed sections are exposed through the loading context while
    # the model is validated from the JSON header.
    with loading_model_from_binary(input_binary=inflated_input, grid_binary=inflated_grid):
        model = GeoModel.model_validate_json(header_json)
    return model
125162

126163

127-
def _to_binary(header_json, body_) -> bytes:
164+
def _to_binary(header_json, body_input, body_grid) -> bytes:
128165
header_json_bytes = header_json.encode('utf-8')
129166
header_json_length = len(header_json_bytes)
130167
header_json_length_bytes = header_json_length.to_bytes(4, byteorder='little')
131-
file = header_json_length_bytes + header_json_bytes + body_
168+
file = header_json_length_bytes + header_json_bytes + body_input + body_grid
132169
return file
133170

134171

135172
def _validate_serialization(original_model, model_deserialized):
136-
if False:
137-
_verify_models(model_deserialized, original_model)
138-
139173
a = hash(original_model.structural_frame.surface_points_copy.data.tobytes())
140174
b = hash(model_deserialized.structural_frame.surface_points_copy.data.tobytes())
141175
o_a = hash(original_model.structural_frame.orientations_copy.data.tobytes())
142176
o_b = hash(model_deserialized.structural_frame.orientations_copy.data.tobytes())
143177
assert a == b, "Hashes for surface points are not equal"
144178
assert o_a == o_b, "Hashes for orientations are not equal"
145-
original_model___str__ = re.sub(r'\s+', ' ', original_model.__str__())
179+
original_model___str__ = re.sub(r'\s+', ' ', original_model.__str__())
146180
deserialized___str__ = re.sub(r'\s+', ' ', model_deserialized.__str__())
147181
if original_model___str__ != deserialized___str__:
148182
# Find first char that is not the same

requirements/dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
pytest
44
pyvista
55
subsurface-terra
6+
wellpathpy
67

78
# Testing
89
pytest-approvaltests

test/test_modules/_geophysics_TO_UPDATE/test_gravity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_gravity():
5959
structural_frame=frame,
6060
)
6161

62-
gp.compute_model(geo_model, validate_serialization=False)
62+
gp.compute_model(geo_model, validate_serialization=True)
6363

6464
import gempy_viewer as gpv
6565
gpv.plot_2d(geo_model, cell_number=0)

0 commit comments

Comments
 (0)