Skip to content

Commit 36c6091

Browse files
committed
[ENH] Add serialization validation to save_model
Incorporate an optional validation step in `save_model` to ensure correct serialization and deserialization of GeoModels. Added `_validate_serialization` function and updated the workflow to compare model hashes and string representations.
1 parent 9b49116 commit 36c6091

File tree

1 file changed

+31
-15
lines changed

1 file changed

+31
-15
lines changed

gempy/modules/serialization/save_load.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66

77

8-
def save_model(model: GeoModel, path: str):
8+
def save_model(model: GeoModel, path: str, validate_serialization: bool = True):
99
"""
1010
Save a GeoModel to a file with proper extension validation.
1111
@@ -23,7 +23,7 @@ def save_model(model: GeoModel, path: str):
2323
"""
2424
# Define the valid extension for gempy models
2525
VALID_EXTENSION = ".gempy"
26-
26+
2727
# Check if path has an extension
2828
path_obj = pathlib.Path(path)
2929
if path_obj.suffix:
@@ -33,25 +33,30 @@ def save_model(model: GeoModel, path: str):
3333
else:
3434
# If no extension, add the valid extension
3535
path = str(path_obj) + VALID_EXTENSION
36-
36+
3737
model_json = model.model_dump_json(by_alias=True, indent=4)
3838

3939
# Compress the binary data
4040
zlib = require_zlib()
4141
compressed_binary = zlib.compress(model.structural_frame.input_tables_binary)
4242

4343
binary_file = _to_binary(model_json, compressed_binary)
44-
44+
45+
if validate_serialization:
46+
model_deserialized = _deserialize_binary_file(binary_file)
47+
_validate_serialization(model, model_deserialized)
48+
4549
# Create directory if it doesn't exist
4650
directory = os.path.dirname(path)
4751
if directory and not os.path.exists(directory):
4852
os.makedirs(directory)
49-
53+
5054
with open(path, 'wb') as f:
5155
f.write(binary_file)
52-
56+
5357
return path # Return the actual path used (helpful if extension was added)
54-
58+
59+
5560
def load_model(path: str) -> GeoModel:
5661
"""
5762
Load a GeoModel from a file with extension validation.
@@ -74,33 +79,34 @@ def load_model(path: str) -> GeoModel:
7479
If the file doesn't exist
7580
"""
7681
VALID_EXTENSION = ".gempy"
77-
82+
7883
# Check if path has the valid extension
7984
path_obj = pathlib.Path(path)
8085
if not path_obj.suffix or path_obj.suffix.lower() != VALID_EXTENSION:
8186
raise ValueError(f"Invalid file extension: {path_obj.suffix}. Expected: {VALID_EXTENSION}")
82-
87+
8388
# Check if file exists
8489
if not os.path.exists(path):
8590
raise FileNotFoundError(f"File not found: {path}")
86-
91+
8792
with open(path, 'rb') as f:
8893
binary_file = f.read()
8994

95+
return _deserialize_binary_file(binary_file)
96+
97+
98+
def _deserialize_binary_file(binary_file):
9099
# Get header length from first 4 bytes
91100
header_length = int.from_bytes(binary_file[:4], byteorder='little')
92-
93101
# Split header and body
94102
header_json = binary_file[4:4 + header_length].decode('utf-8')
95103
binary_body = binary_file[4 + header_length:]
96-
97104
zlib = require_zlib()
98105
decompressed_binary = zlib.decompress(binary_body)
99106
with loading_model_from_binary(
100-
binary_body=decompressed_binary,
107+
binary_body=decompressed_binary,
101108
):
102109
model = GeoModel.model_validate_json(header_json)
103-
104110
return model
105111

106112

@@ -109,4 +115,14 @@ def _to_binary(header_json, body_) -> bytes:
109115
header_json_length = len(header_json_bytes)
110116
header_json_length_bytes = header_json_length.to_bytes(4, byteorder='little')
111117
file = header_json_length_bytes + header_json_bytes + body_
112-
return file
118+
return file
119+
120+
121+
def _validate_serialization(original_model, model_deserialized):
122+
a = hash(original_model.structural_frame.surface_points_copy.data.tobytes())
123+
b = hash(model_deserialized.structural_frame.surface_points_copy.data.tobytes())
124+
o_a = hash(original_model.structural_frame.orientations_copy.data.tobytes())
125+
o_b = hash(model_deserialized.structural_frame.orientations_copy.data.tobytes())
126+
assert a == b, "Hashes for surface points are not equal"
127+
assert o_a == o_b, "Hashes for orientations are not equal"
128+
assert model_deserialized.__str__() == original_model.__str__()

0 commit comments

Comments
 (0)