1
1
from ...core .data import GeoModel
2
2
from ...core .data .encoders .converters import loading_model_from_binary
3
3
from ...optional_dependencies import require_zlib
4
+ import pathlib
5
+ import os
4
6
5
7
6
8
def save_model (model : GeoModel , path : str ):
9
+ """
10
+ Save a GeoModel to a file with proper extension validation.
11
+
12
+ Parameters:
13
+ -----------
14
+ model : GeoModel
15
+ The geological model to save
16
+ path : str
17
+ The file path where to save the model
18
+
19
+ Raises:
20
+ -------
21
+ ValueError
22
+ If the file has an extension other than .gempy
23
+ """
24
+ # Define the valid extension for gempy models
25
+ VALID_EXTENSION = ".gempy"
26
+
27
+ # Check if path has an extension
28
+ path_obj = pathlib .Path (path )
29
+ if path_obj .suffix :
30
+ # If extension exists but is not valid, raise error
31
+ if path_obj .suffix .lower () != VALID_EXTENSION :
32
+ raise ValueError (f"Invalid file extension: { path_obj .suffix } . Expected: { VALID_EXTENSION } " )
33
+ else :
34
+ # If no extension, add the valid extension
35
+ path = str (path_obj ) + VALID_EXTENSION
7
36
8
37
model_json = model .model_dump_json (by_alias = True , indent = 4 )
9
38
@@ -13,14 +42,48 @@ def save_model(model: GeoModel, path: str):
13
42
14
43
binary_file = _to_binary (model_json , compressed_binary )
15
44
16
- # TODO: Add validation
45
+ # Create directory if it doesn't exist
46
+ directory = os .path .dirname (path )
47
+ if directory and not os .path .exists (directory ):
48
+ os .makedirs (directory )
17
49
18
50
with open (path , 'wb' ) as f :
19
51
f .write (binary_file )
20
-
21
52
53
+ return path # Return the actual path used (helpful if extension was added)
22
54
23
55
def load_model (path : str ) -> GeoModel :
56
+ """
57
+ Load a GeoModel from a file with extension validation.
58
+
59
+ Parameters:
60
+ -----------
61
+ path : str
62
+ Path to the gempy model file
63
+
64
+ Returns:
65
+ --------
66
+ GeoModel
67
+ The loaded geological model
68
+
69
+ Raises:
70
+ -------
71
+ ValueError
72
+ If the file doesn't have the proper .gempy extension
73
+ FileNotFoundError
74
+ If the file doesn't exist
75
+ """
76
+ VALID_EXTENSION = ".gempy"
77
+
78
+ # Check if path has the valid extension
79
+ path_obj = pathlib .Path (path )
80
+ if not path_obj .suffix or path_obj .suffix .lower () != VALID_EXTENSION :
81
+ raise ValueError (f"Invalid file extension: { path_obj .suffix } . Expected: { VALID_EXTENSION } " )
82
+
83
+ # Check if file exists
84
+ if not os .path .exists (path ):
85
+ raise FileNotFoundError (f"File not found: { path } " )
86
+
24
87
with open (path , 'rb' ) as f :
25
88
binary_file = f .read ()
26
89
@@ -46,6 +109,4 @@ def _to_binary(header_json, body_) -> bytes:
46
109
header_json_length = len (header_json_bytes )
47
110
header_json_length_bytes = header_json_length .to_bytes (4 , byteorder = 'little' )
48
111
file = header_json_length_bytes + header_json_bytes + body_
49
- return file
50
-
51
-
112
+ return file
0 commit comments