Skip to content

Commit 9b49116

Browse files
committed
[ENH] Add file extension validation and directory creation to save/load
Ensure correct `.gempy` extension when saving/loading models, raising errors for invalid or missing extensions. Added automatic creation of directories during save operations. Updated tests to reflect the enforced `.gempy` extension.
1 parent 0308f5a commit 9b49116

File tree

2 files changed

+68
-7
lines changed

2 files changed

+68
-7
lines changed

gempy/modules/serialization/save_load.py

Lines changed: 66 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,38 @@
11
from ...core.data import GeoModel
22
from ...core.data.encoders.converters import loading_model_from_binary
33
from ...optional_dependencies import require_zlib
4+
import pathlib
5+
import os
46

57

68
def save_model(model: GeoModel, path: str):
9+
"""
10+
Save a GeoModel to a file with proper extension validation.
11+
12+
Parameters:
13+
-----------
14+
model : GeoModel
15+
The geological model to save
16+
path : str
17+
The file path where to save the model
18+
19+
Raises:
20+
-------
21+
ValueError
22+
If the file has an extension other than .gempy
23+
"""
24+
# Define the valid extension for gempy models
25+
VALID_EXTENSION = ".gempy"
26+
27+
# Check if path has an extension
28+
path_obj = pathlib.Path(path)
29+
if path_obj.suffix:
30+
# If extension exists but is not valid, raise error
31+
if path_obj.suffix.lower() != VALID_EXTENSION:
32+
raise ValueError(f"Invalid file extension: {path_obj.suffix}. Expected: {VALID_EXTENSION}")
33+
else:
34+
# If no extension, add the valid extension
35+
path = str(path_obj) + VALID_EXTENSION
736

837
model_json = model.model_dump_json(by_alias=True, indent=4)
938

@@ -13,14 +42,48 @@ def save_model(model: GeoModel, path: str):
1342

1443
binary_file = _to_binary(model_json, compressed_binary)
1544

16-
# TODO: Add validation
45+
# Create directory if it doesn't exist
46+
directory = os.path.dirname(path)
47+
if directory and not os.path.exists(directory):
48+
os.makedirs(directory)
1749

1850
with open(path, 'wb') as f:
1951
f.write(binary_file)
20-
2152

53+
return path # Return the actual path used (helpful if extension was added)
2254

2355
def load_model(path: str) -> GeoModel:
56+
"""
57+
Load a GeoModel from a file with extension validation.
58+
59+
Parameters:
60+
-----------
61+
path : str
62+
Path to the gempy model file
63+
64+
Returns:
65+
--------
66+
GeoModel
67+
The loaded geological model
68+
69+
Raises:
70+
-------
71+
ValueError
72+
If the file doesn't have the proper .gempy extension
73+
FileNotFoundError
74+
If the file doesn't exist
75+
"""
76+
VALID_EXTENSION = ".gempy"
77+
78+
# Check if path has the valid extension
79+
path_obj = pathlib.Path(path)
80+
if not path_obj.suffix or path_obj.suffix.lower() != VALID_EXTENSION:
81+
raise ValueError(f"Invalid file extension: {path_obj.suffix}. Expected: {VALID_EXTENSION}")
82+
83+
# Check if file exists
84+
if not os.path.exists(path):
85+
raise FileNotFoundError(f"File not found: {path}")
86+
2487
with open(path, 'rb') as f:
2588
binary_file = f.read()
2689

@@ -46,6 +109,4 @@ def _to_binary(header_json, body_) -> bytes:
46109
header_json_length = len(header_json_bytes)
47110
header_json_length_bytes = header_json_length.to_bytes(4, byteorder='little')
48111
file = header_json_length_bytes + header_json_bytes + body_
49-
return file
50-
51-
112+
return file

test/test_modules/test_serialize_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ def _validate_serialization(original_model, model_deserialized):
4949

5050
def test_save_model_to_disk():
5151
model = gp.generate_example_model(ExampleModel.COMBINATION, compute_model=False)
52-
save_model(model, "temp/test_save_model_to_disk.json")
52+
save_model(model, "temp/test_save_model_to_disk.gempy")
5353

5454
# Load the model from disk
55-
loaded_model = load_model("temp/test_save_model_to_disk.json")
55+
loaded_model = load_model("temp/test_save_model_to_disk.gempy")
5656
_validate_serialization(model, loaded_model)
5757

5858
gp.compute_model(loaded_model)

0 commit comments

Comments
 (0)