5
5
import os
6
6
7
7
8
- def save_model (model : GeoModel , path : str ):
8
+ def save_model (model : GeoModel , path : str , validate_serialization : bool = True ):
9
9
"""
10
10
Save a GeoModel to a file with proper extension validation.
11
11
@@ -23,7 +23,7 @@ def save_model(model: GeoModel, path: str):
23
23
"""
24
24
# Define the valid extension for gempy models
25
25
VALID_EXTENSION = ".gempy"
26
-
26
+
27
27
# Check if path has an extension
28
28
path_obj = pathlib .Path (path )
29
29
if path_obj .suffix :
@@ -33,25 +33,30 @@ def save_model(model: GeoModel, path: str):
33
33
else :
34
34
# If no extension, add the valid extension
35
35
path = str (path_obj ) + VALID_EXTENSION
36
-
36
+
37
37
model_json = model .model_dump_json (by_alias = True , indent = 4 )
38
38
39
39
# Compress the binary data
40
40
zlib = require_zlib ()
41
41
compressed_binary = zlib .compress (model .structural_frame .input_tables_binary )
42
42
43
43
binary_file = _to_binary (model_json , compressed_binary )
44
-
44
+
45
+ if validate_serialization :
46
+ model_deserialized = _deserialize_binary_file (binary_file )
47
+ _validate_serialization (model , model_deserialized )
48
+
45
49
# Create directory if it doesn't exist
46
50
directory = os .path .dirname (path )
47
51
if directory and not os .path .exists (directory ):
48
52
os .makedirs (directory )
49
-
53
+
50
54
with open (path , 'wb' ) as f :
51
55
f .write (binary_file )
52
-
56
+
53
57
return path # Return the actual path used (helpful if extension was added)
54
-
58
+
59
+
55
60
def load_model (path : str ) -> GeoModel :
56
61
"""
57
62
Load a GeoModel from a file with extension validation.
@@ -74,33 +79,34 @@ def load_model(path: str) -> GeoModel:
74
79
If the file doesn't exist
75
80
"""
76
81
VALID_EXTENSION = ".gempy"
77
-
82
+
78
83
# Check if path has the valid extension
79
84
path_obj = pathlib .Path (path )
80
85
if not path_obj .suffix or path_obj .suffix .lower () != VALID_EXTENSION :
81
86
raise ValueError (f"Invalid file extension: { path_obj .suffix } . Expected: { VALID_EXTENSION } " )
82
-
87
+
83
88
# Check if file exists
84
89
if not os .path .exists (path ):
85
90
raise FileNotFoundError (f"File not found: { path } " )
86
-
91
+
87
92
with open (path , 'rb' ) as f :
88
93
binary_file = f .read ()
89
94
95
+ return _deserialize_binary_file (binary_file )
96
+
97
+
98
+ def _deserialize_binary_file (binary_file ):
90
99
# Get header length from first 4 bytes
91
100
header_length = int .from_bytes (binary_file [:4 ], byteorder = 'little' )
92
-
93
101
# Split header and body
94
102
header_json = binary_file [4 :4 + header_length ].decode ('utf-8' )
95
103
binary_body = binary_file [4 + header_length :]
96
-
97
104
zlib = require_zlib ()
98
105
decompressed_binary = zlib .decompress (binary_body )
99
106
with loading_model_from_binary (
100
- binary_body = decompressed_binary ,
107
+ binary_body = decompressed_binary ,
101
108
):
102
109
model = GeoModel .model_validate_json (header_json )
103
-
104
110
return model
105
111
106
112
@@ -109,4 +115,14 @@ def _to_binary(header_json, body_) -> bytes:
109
115
header_json_length = len (header_json_bytes )
110
116
header_json_length_bytes = header_json_length .to_bytes (4 , byteorder = 'little' )
111
117
file = header_json_length_bytes + header_json_bytes + body_
112
- return file
118
+ return file
119
+
120
+
121
+ def _validate_serialization (original_model , model_deserialized ):
122
+ a = hash (original_model .structural_frame .surface_points_copy .data .tobytes ())
123
+ b = hash (model_deserialized .structural_frame .surface_points_copy .data .tobytes ())
124
+ o_a = hash (original_model .structural_frame .orientations_copy .data .tobytes ())
125
+ o_b = hash (model_deserialized .structural_frame .orientations_copy .data .tobytes ())
126
+ assert a == b , "Hashes for surface points are not equal"
127
+ assert o_a == o_b , "Hashes for orientations are not equal"
128
+ assert model_deserialized .__str__ () == original_model .__str__ ()
0 commit comments