Skip to content

Commit 122669c

Browse files
committed
[WIP] Improving body serialization
1 parent 64b4746 commit 122669c

File tree

4 files changed

+66
-22
lines changed

4 files changed

+66
-22
lines changed

gempy/modules/serialization/save_load.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,39 @@
1+
import json
2+
3+
import numpy as np
4+
15
from ...core.data import GeoModel
26
from ...core.data.encoders.converters import loading_model_injection
7+
from ...optional_dependencies import require_zlib
38

49

510
def save_model(model: GeoModel, path: str):
    """Serialize ``model`` to a single binary file at ``path``.

    The file is whatever ``_to_binary`` builds from two parts:

    * a JSON header: the model dumped with ``by_alias=True`` plus a
      ``"_binary_metadata"`` entry describing the binary body;
    * a zlib-compressed body: the raw surface-point bytes immediately
      followed by the raw orientation bytes.

    ``load_model`` pops ``"_binary_metadata"`` from the header, so the
    metadata has to be part of the header that is actually written.

    Args:
        model: The model to persist.
        path: Destination file; opened in ``'wb'`` mode (overwritten).
    """
    import zlib

    structural_frame = model.structural_frame
    sp_data: np.ndarray = structural_frame.surface_points_copy.data
    ori_data: np.ndarray = structural_frame.orientations_copy.data

    sp_binary = sp_data.tobytes()
    ori_binary = ori_data.tobytes()

    # Both arrays are compressed as one stream; "sp_length" below is what
    # lets the reader split them apart again after decompression.
    compressed_binary = zlib.compress(sp_binary + ori_binary)

    # Start from the JSON dump (not model_dump) so every value in the header
    # is already JSON-encodable, then attach the description of the body.
    header_dict = json.loads(model.model_dump_json(by_alias=True))
    header_dict["_binary_metadata"] = {
        "sp_shape": list(sp_data.shape),
        "sp_dtype": str(sp_data.dtype),
        "ori_shape": list(ori_data.shape),
        "ori_dtype": str(ori_data.dtype),
        "compression": "zlib",
        "sp_length": len(sp_binary),  # byte offset where orientations start
    }
    # NOTE(review): load_model validates the raw header JSON with
    # GeoModel.model_validate_json; confirm the model tolerates the extra
    # "_binary_metadata" key or strip it there before validating.
    header_json = json.dumps(header_dict, indent=4)

    binary_file = _to_binary(header_json, compressed_binary)
    with open(path, 'wb') as f:
        f.write(binary_file)
1839

@@ -25,22 +46,35 @@ def load_model(path: str) -> GeoModel:
2546

2647
# Split header and body
2748
header_json = binary_file[4:4 + header_length].decode('utf-8')
28-
body = binary_file[4 + header_length:]
49+
header_dict = json.loads(header_json)
2950

30-
# Split body into surface points and orientations
31-
# They are equal size so we can split in half
32-
sp_binary = body[:len(body) // 2]
33-
ori_binary = body[len(body) // 2:]
51+
metadata = header_dict.pop("_binary_metadata")
52+
53+
# Decompress the binary data
54+
ori_data, sp_data = _foo(binary_file, header_length, metadata)
3455

3556
with loading_model_injection(
36-
surface_points_binary=sp_binary,
37-
orientations_binary=ori_binary
57+
surface_points_binary=sp_data,
58+
orientations_binary=ori_data
3859
):
3960
model = GeoModel.model_validate_json(header_json)
4061

4162
return model
4263

4364

65+
def _foo(binary_file, header_length, metadata):
    """Decompress the file body and rebuild the two data arrays.

    The body is everything after the first ``4 + header_length`` bytes of
    ``binary_file``. Once decompressed, its first ``metadata["sp_length"]``
    bytes are the surface points and the remainder are the orientations.

    Args:
        binary_file: Full file content as bytes.
        header_length: Byte length of the JSON header.
        metadata: Mapping with the keys ``sp_length``, ``sp_dtype``,
            ``sp_shape``, ``ori_dtype`` and ``ori_shape``.

    Returns:
        ``(orientations, surface_points)`` - orientations come first.
    """
    zlib = require_zlib()
    payload = zlib.decompress(binary_file[4 + header_length:])

    def _rebuild(raw, dtype_key, shape_key):
        # NOTE(review): np.dtype() on a stored string presumably only
        # round-trips simple dtypes - confirm the arrays are not structured.
        flat = np.frombuffer(raw, dtype=np.dtype(metadata[dtype_key]))
        return flat.reshape(metadata[shape_key])

    # The stored length marks where surface points end and orientations begin.
    surface_points_raw = payload[:metadata["sp_length"]]
    orientations_raw = payload[metadata["sp_length"]:]

    sp_data = _rebuild(surface_points_raw, "sp_dtype", "sp_shape")
    ori_data = _rebuild(orientations_raw, "ori_dtype", "ori_shape")
    return ori_data, sp_data
76+
77+
4478
def _to_binary(header_json, body_) -> bytes:
4579
header_json_bytes = header_json.encode('utf-8')
4680
header_json_length = len(header_json_bytes)

gempy/optional_dependencies.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,11 @@ def require_subsurface():
5858
import subsurface
5959
except ImportError:
6060
raise ImportError("The subsurface package is required to run this function.")
61-
return subsurface
61+
return subsurface
62+
63+
def require_zlib():
    """Return the ``zlib`` module, or raise if it cannot be imported.

    Returns:
        The imported ``zlib`` module.

    Raises:
        ImportError: If ``zlib`` is not importable; the original import
            failure is attached as the explicit cause.
    """
    try:
        import zlib
    except ImportError as err:
        # Chain explicitly so the underlying import failure stays visible.
        raise ImportError("The zlib package is required to run this function.") from err
    return zlib

requirements/optional-requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ gempy_plugins
55
# for data download
66
pooch
77
scipy
8-
scikit-image
8+
scikit-image
9+
# zlib ships with the Python standard library; no pip requirement is needed

test/test_modules/test_serialize_model.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,7 @@ def test_generate_horizontal_stratigraphic_model():
2626
):
2727
model_deserialized = gp.data.GeoModel.model_validate_json(model_json)
2828

29-
a = hash(model.structural_frame.structural_elements[1].surface_points.data.tobytes())
30-
b = hash(model_deserialized.structural_frame.structural_elements[1].surface_points.data.tobytes())
31-
32-
o_a = hash(model.structural_frame.structural_elements[1].orientations.data.tobytes())
33-
o_b = hash(model_deserialized.structural_frame.structural_elements[1].orientations.data.tobytes())
34-
35-
assert a == b, "Hashes for surface points are not equal"
36-
assert o_a == o_b, "Hashes for orientations are not equal"
37-
assert model_deserialized.__str__() == model.__str__()
29+
_validate_serialization(model, model_deserialized)
3830

3931
# # Validate json against schema
4032
if True:
@@ -43,15 +35,25 @@ def test_generate_horizontal_stratigraphic_model():
4335
verify_model = json.loads(model_json)
4436
verify_model["meta"]["creation_date"] = "<DATE_IGNORED>"
4537
verify_json(json.dumps(verify_model, indent=4), name="verify/Horizontal Stratigraphic Model serialization")
46-
38+
39+
40+
def _validate_serialization(original_model, model_deserialized):
    """Assert that a deserialized model matches the model it came from.

    Checks the raw bytes of the surface points and orientations of the
    structural element at index 1, then the string representation of both
    models.

    Args:
        original_model: The model before serialization.
        model_deserialized: The model rebuilt from the serialized form.

    Raises:
        AssertionError: If any of the comparisons fails.
    """
    original_element = original_model.structural_frame.structural_elements[1]
    restored_element = model_deserialized.structural_frame.structural_elements[1]

    # Compare the bytes themselves: equal hashes do not prove equal data.
    sp_original = original_element.surface_points.data.tobytes()
    sp_restored = restored_element.surface_points.data.tobytes()
    ori_original = original_element.orientations.data.tobytes()
    ori_restored = restored_element.orientations.data.tobytes()

    assert sp_original == sp_restored, "Surface points data are not equal"
    assert ori_original == ori_restored, "Orientations data are not equal"
    assert str(model_deserialized) == str(original_model)
48+
4749

4850
def test_save_model_to_disk():
    """Round-trip the COMBINATION example model through save/load on disk.

    The file lives in a throwaway temporary directory, so the test does not
    depend on a pre-existing ``temp/`` folder in the working directory and
    leaves nothing behind.
    """
    import os
    import tempfile

    model = gp.generate_example_model(ExampleModel.COMBINATION, compute_model=False)

    with tempfile.TemporaryDirectory() as tmp_dir:
        file_path = os.path.join(tmp_dir, "test_save_model_to_disk.json")
        save_model(model, file_path)

        # Load the model back from disk and compare it with the original
        loaded_model = load_model(file_path)
        _validate_serialization(model, loaded_model)
5557

5658

5759

0 commit comments

Comments
 (0)