Skip to content

Commit b6598ff

Browse files
committed
[ENH] Adding logic to test serialization on compute model
1 parent 2be0238 commit b6598ff

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

gempy/API/compute_API.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from typing import Optional
1+
import dotenv
2+
import os
3+
4+
from typing import Optional
25

36
import numpy as np
47

@@ -14,6 +17,8 @@
1417
from ..modules.data_manipulation.engine_factory import interpolation_input_from_structural_frame
1518
from ..optional_dependencies import require_gempy_legacy
1619

20+
dotenv.load_dotenv()
21+
1722

1823
def compute_model(gempy_model: GeoModel, engine_config: Optional[GemPyEngineConfig] = None) -> Solutions:
1924
"""
@@ -56,6 +61,12 @@ def compute_model(gempy_model: GeoModel, engine_config: Optional[GemPyEngineConf
5661
case _:
5762
raise ValueError(f'Backend {engine_config} not supported')
5863

64+
if os.getenv("VALIDATE_SERIALIZATION", False):
65+
from ..modules.serialization.save_load import save_model
66+
import tempfile
67+
with tempfile.NamedTemporaryFile(mode='w+', delete=True) as tmp:
68+
save_model(model=gempy_model, path=tmp.name, validate_serialization=True)
69+
5970
return gempy_model.solutions
6071

6172

gempy/modules/serialization/save_load.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66

77

8-
def save_model(model: GeoModel, path: str, validate_serialization: bool = True):
8+
def save_model(model: GeoModel, path: str | None = None, validate_serialization: bool = True):
99
"""
1010
Save a GeoModel to a file with proper extension validation.
1111
@@ -23,6 +23,8 @@ def save_model(model: GeoModel, path: str, validate_serialization: bool = True):
2323
"""
2424
# Define the valid extension for gempy models
2525
VALID_EXTENSION = ".gempy"
26+
if path is None:
27+
path = model.meta.name + VALID_EXTENSION
2628

2729
# Check if path has an extension
2830
path_obj = pathlib.Path(path)

0 commit comments

Comments
 (0)