Skip to content

Commit 0211f96

Browse files
committed
[WIP] Deserialize binary seems to run
1 parent 122669c commit 0211f96

File tree

5 files changed

+125
-60
lines changed

5 files changed

+125
-60
lines changed

gempy/core/data/encoders/converters.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,14 @@ def loading_model_injection(surface_points_binary: np.ndarray, orientations_bina
5757
finally:
5858
loading_model_context.reset(token)
5959

60+
61+
@contextmanager
def loading_model_from_binary(binary_body: bytes):
    """Publish a raw binary payload on ``loading_model_context`` for a ``with`` block.

    While the ``with`` body runs, ``loading_model_context`` holds a dict whose
    only key, ``'binary_body'``, maps to ``binary_body``. When the body exits,
    whether normally or through an exception, the context is reset to the
    value it held before entry.

    Args:
        binary_body: Bytes to make available to code that reads
            ``loading_model_context`` while the block is active.

    Yields:
        None. The payload is reached through ``loading_model_context``,
        not through the ``as`` target.
    """
    injected_payload = {'binary_body': binary_body}
    reset_token = loading_model_context.set(injected_payload)
    try:
        yield
    finally:
        # Always undo the injection so the payload cannot leak past this block.
        loading_model_context.reset(reset_token)
70+

gempy/core/data/geo_model.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -302,23 +302,20 @@ def add_surface_points(self, X: Sequence[float], Y: Sequence[float], Z: Sequence
302302
@model_validator(mode='wrap')
303303
@classmethod
304304
def deserialize_properties(cls, data: Union["GeoModel", dict], constructor: ModelWrapValidatorHandler["GeoModel"]) -> "GeoModel":
305-
try:
306-
match data:
307-
case GeoModel():
308-
return data
309-
case dict():
310-
instance: GeoModel = constructor(data)
311-
instantiate_if_necessary(
312-
data=data,
313-
key="_interpolation_options",
314-
type=InterpolationOptions
315-
)
316-
instance._interpolation_options = data.get("_interpolation_options")
317-
return instance
318-
case _:
319-
raise ValidationError
320-
except ValidationError:
321-
raise
305+
match data:
306+
case GeoModel():
307+
return data
308+
case dict():
309+
instance: GeoModel = constructor(data)
310+
instantiate_if_necessary(
311+
data=data,
312+
key="_interpolation_options",
313+
type=InterpolationOptions
314+
)
315+
instance._interpolation_options = data.get("_interpolation_options")
316+
return instance
317+
case _:
318+
raise ValidationError
322319

323320
# endregion
324321

gempy/core/data/structural_frame.py

Lines changed: 77 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import numpy as np
44
import warnings
55
from dataclasses import dataclass
6-
from pydantic import model_validator, computed_field
7-
from typing import Generator
6+
from pydantic import model_validator, computed_field, ValidationError
7+
from pydantic.functional_validators import ModelWrapValidatorHandler
8+
from typing import Generator, Union
89

910
from gempy_engine.core.data.input_data_descriptor import InputDataDescriptor
1011
from gempy_engine.core.data.kernel_classes.faults import FaultsData
@@ -33,47 +34,101 @@ class StructuralFrame:
3334
# ? Should I create some sort of structural options class? For example, the masking descriptor and faults relations pointer
3435
is_dirty: bool = True
3536

37+
@model_validator(mode="wrap")
38+
@classmethod
39+
def deserialize_binary(cls, data: Union["StructuralFrame", dict], constructor: ModelWrapValidatorHandler["StructuralFrame"]) -> "StructuralFrame":
40+
match data:
41+
case StructuralFrame():
42+
return data
43+
case dict():
44+
instance: StructuralFrame = constructor(data)
45+
metadata = data.get('binary_meta_data', {})
46+
47+
context = loading_model_context.get()
48+
49+
if 'binary_body' not in context:
50+
return instance
51+
52+
binary_array = context['binary_body']
53+
54+
sp_binary = binary_array[:metadata["sp_binary_length"]]
55+
ori_binary = binary_array[metadata["sp_binary_length"]:]
56+
57+
# Reconstruct arrays
58+
sp_data: np.ndarray = np.frombuffer(sp_binary, dtype=SurfacePointsTable.dt)
59+
ori_data: np.ndarray = np.frombuffer(ori_binary, dtype=OrientationsTable.dt)
60+
61+
instance.surface_points = SurfacePointsTable(
62+
data=sp_data,
63+
name_id_map=instance.surface_points_copy.name_id_map
64+
)
65+
66+
instance.orientations = OrientationsTable(
67+
data=ori_data,
68+
name_id_map=instance.orientations_copy.name_id_map
69+
)
70+
71+
return instance
72+
case _:
73+
raise ValidationError(f"Invalid data type for StructuralFrame: {type(data)}")
74+
75+
# Access the context variable to get injected data
76+
3677
@model_validator(mode="after")
37-
def deserialize_surface_points(values: "StructuralFrame"):
78+
def deserialize_surface_points(self: "StructuralFrame"):
3879
# Access the context variable to get injected data
3980
context = loading_model_context.get()
4081

4182
if 'surface_points_binary' not in context:
42-
return values
83+
return self
4384

4485
# Check if we have a binary payload to digest
4586
binary_array = context['surface_points_binary']
4687
if not isinstance(binary_array, np.ndarray):
47-
return values
88+
return self
4889
if binary_array.shape[0] < 1:
49-
return values
50-
51-
values.surface_points = SurfacePointsTable(
90+
return self
91+
92+
self.surface_points = SurfacePointsTable(
5293
data=binary_array,
53-
name_id_map=values.surface_points_copy.name_id_map
94+
name_id_map=self.surface_points_copy.name_id_map
5495
)
55-
56-
return values
57-
96+
97+
return self
98+
5899
@model_validator(mode="after")
59-
def deserialize_orientations(values: "StructuralFrame"):
100+
def deserialize_orientations(self: "StructuralFrame"):
101+
# TODO: Check here the binary size of surface_points_binary
102+
60103
# Access the context variable to get injected data
61104
context = loading_model_context.get()
62105
if 'orientations_binary' not in context:
63-
return values
64-
106+
return self
107+
65108
# Check if we have a binary payload to digest
66109
binary_array = context['orientations_binary']
67110
if not isinstance(binary_array, np.ndarray):
68-
return values
69-
70-
values.orientations = OrientationsTable(
111+
return self
112+
113+
self.orientations = OrientationsTable(
71114
data=binary_array,
72-
name_id_map=values.orientations_copy.name_id_map
115+
name_id_map=self.orientations_copy.name_id_map
73116
)
74-
75-
return values
76-
117+
118+
return self
119+
120+
@computed_field
@property
def binary_meta_data(self) -> dict:
    """Describe the binary layout of the surface-point and orientation arrays.

    Reads ``self.surface_points_copy.data`` and ``self.orientations_copy.data``
    and reports, for each, its shape, its dtype as a string, and its size in
    bytes. ``sp_binary_length`` is the offset ``deserialize_binary`` uses to
    split a concatenated binary body back into the two arrays.

    Returns:
        A dict with the keys ``sp_shape``, ``sp_dtype``, ``sp_binary_length``,
        ``ori_shape``, ``ori_dtype`` and ``ori_binary_length``.
    """
    sp_data = self.surface_points_copy.data
    ori_data = self.orientations_copy.data
    return {
            'sp_shape'         : sp_data.shape,
            'sp_dtype'         : str(sp_data.dtype),
            # nbytes equals len(arr.tobytes()) without materialising a copy of the array.
            'sp_binary_length' : sp_data.nbytes,
            'ori_shape'        : ori_data.shape,
            'ori_dtype'        : str(ori_data.dtype),
            'ori_binary_length': ori_data.nbytes
    }
77132

78133
@computed_field
79134
@property

gempy/core/data/surface_points.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from dataclasses import dataclass
1+
import numpy as np
2+
from dataclasses import dataclass
23
from pydantic import field_validator, SkipValidation
3-
from typing import Optional, Union, Sequence, Annotated
4-
import numpy as np
4+
from typing import Optional, Union, Sequence
55

6-
from ._data_points_helpers import generate_ids_from_names
7-
from .encoders.converters import numpy_array_short_validator
86
from gempy_engine.core.data.transforms import Transform
9-
from gempy.optional_dependencies import require_pandas
7+
8+
from ...optional_dependencies import require_pandas
9+
from ._data_points_helpers import generate_ids_from_names
1010

1111
DEFAULT_SP_NUGGET = 0.00002
1212

gempy/modules/serialization/save_load.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44

55
from ...core.data import GeoModel
6-
from ...core.data.encoders.converters import loading_model_injection
6+
from ...core.data.encoders.converters import loading_model_injection, loading_model_from_binary
77
from ...optional_dependencies import require_zlib
88

99

@@ -22,15 +22,15 @@ def save_model(model: GeoModel, path: str):
2222
compressed_binary = zlib.compress(sp_binary + ori_binary)
2323

2424
# Add compression info to metadata
25-
model_dict = model.model_dump(by_alias=True)
26-
model_dict["_binary_metadata"] = {
27-
"sp_shape" : model.structural_frame.surface_points_copy.data.shape,
28-
"sp_dtype" : str(model.structural_frame.surface_points_copy.data.dtype),
29-
"ori_shape" : model.structural_frame.orientations_copy.data.shape,
30-
"ori_dtype" : str(model.structural_frame.orientations_copy.data.dtype),
31-
"compression": "zlib",
32-
"sp_length" : len(sp_binary) # Need this to split the arrays after decompression
33-
}
25+
# model_dict = model.model_dump(by_alias=True)
26+
# model_dict["_binary_metadata"] = {
27+
# "sp_shape" : model.structural_frame.surface_points_copy.data.shape,
28+
# "sp_dtype" : str(model.structural_frame.surface_points_copy.data.dtype),
29+
# "ori_shape" : model.structural_frame.orientations_copy.data.shape,
30+
# "ori_dtype" : str(model.structural_frame.orientations_copy.data.dtype),
31+
# "compression": "zlib",
32+
# "sp_length" : len(sp_binary) # Need this to split the arrays after decompression
33+
# }
3434

3535
# TODO: Putting both together
3636
binary_file = _to_binary(model_json, compressed_binary)
@@ -48,14 +48,16 @@ def load_model(path: str) -> GeoModel:
4848
header_json = binary_file[4:4 + header_length].decode('utf-8')
4949
header_dict = json.loads(header_json)
5050

51-
metadata = header_dict.pop("_binary_metadata")
51+
# metadata = header_dict.pop("_binary_metadata")
5252

5353
# Decompress the binary data
54-
ori_data, sp_data = _foo(binary_file, header_length, metadata)
54+
# ori_data, sp_data = _foo(binary_file, header_length, metadata)
5555

56-
with loading_model_injection(
57-
surface_points_binary=sp_data,
58-
orientations_binary=ori_data
56+
binary_body = binary_file[4 + header_length:]
57+
zlib = require_zlib()
58+
decompressed_binary = zlib.decompress(binary_body)
59+
with loading_model_from_binary(
60+
binary_body=decompressed_binary,
5961
):
6062
model = GeoModel.model_validate_json(header_json)
6163

0 commit comments

Comments
 (0)