22
33# This example demonstrates a tentative `attrs`-based object model.
44
5+ from os import PathLike
56from pathlib import Path
6- from typing import List , Literal , Optional
7+ from typing import Any , Literal , Optional , get_origin
8+ from warnings import warn
79
10+ import attrs
811import numpy as np
9- from attr import asdict , define , field
12+ from attr import define , field , fields_dict
1013from cattr import Converter
11- from flopy .discretization import StructuredGrid
12- from numpy .typing import NDArray
14+ from numpy .typing import ArrayLike , NDArray
1315from xarray import Dataset , DataTree
1416
1517
16- @define
18+ def _parse_dim_names (s : str ) -> tuple [str ]:
19+ return tuple (
20+ [
21+ ss .strip ()
22+ for ss in s .strip ().replace ("(" , "" ).replace (")" , "" ).split ("," )
23+ if any (ss )
24+ ]
25+ )
26+
27+
28+ def _try_resolve_dim (self , name ) -> int | str :
29+ name = name .strip ()
30+ value = getattr (self , name , None )
31+ if value :
32+ return value
33+ if hasattr (self , "model" ) and hasattr (self .model , "dis" ):
34+ return getattr (self .model .dis , name , name )
35+ return name
36+
37+
38+ def _to_array (value : ArrayLike ) -> Optional [NDArray ]:
39+ return None if value is None else np .array (value )
40+
41+
42+ def _to_shaped_array (
43+ value : ArrayLike | str | PathLike , self_ , field
44+ ) -> Optional [NDArray ]:
45+ if isinstance (value , (str , PathLike )):
46+ # TODO
47+ pass
48+
49+ value = _to_array (value )
50+ if value is None :
51+ return None
52+ dim_names = _parse_dim_names (field .metadata ["shape" ])
53+ shape = tuple ([_try_resolve_dim (self_ , n ) for n in dim_names ])
54+ unresolved = [d for d in shape if not isinstance (d , int )]
55+ if any (unresolved ):
56+ warn (f"Failed to resolve dimension names: { ', ' .join (unresolved )} " )
57+ return value
58+ elif value .shape == ():
59+ return np .ones (shape ) ** value .item ()
60+ elif value .shape != shape :
61+ raise ValueError (
62+ f"Shape mismatch, got { value .shape } , expected { shape } "
63+ )
64+ return value
65+
66+
67+ def _to_path (value ) -> Optional [Path ]:
68+ return Path (value ) if value else None
69+
70+
71+ def datatree (cls ):
72+ # TODO
73+ # - determine whether data array, data set, or data tree DONE
74+ # - shape check arrays (dynamic validator?)
75+ # check for parent and update dimensions
76+ # then try to realign existing packages?
77+
78+ old_post_init = getattr (cls , "__attrs_post_init__" , None )
79+
80+ def __attrs_post_init__ (self ):
81+ print (f"Running datatree on { cls .__name__ } " )
82+
83+ if old_post_init :
84+ old_post_init (self )
85+
86+ fields = fields_dict (cls )
87+ arrays = {}
88+ for n , f in fields .items ():
89+ if get_origin (f .type ) is not np .ndarray :
90+ continue
91+ value = getattr (self , n )
92+ if value is None :
93+ continue
94+ arrays [n ] = (_parse_dim_names (f .metadata ["shape" ]), value )
95+ dataset = Dataset (arrays )
96+ children = getattr (self , "children" , None )
97+ if children :
98+ self .data = DataTree (
99+ dataset , name = cls .__name__ , children = [c .data for c in children ]
100+ )
101+ else :
102+ self .data = dataset
103+
104+ cls .__attrs_post_init__ = __attrs_post_init__
105+
106+ return cls
107+
108+
109+ @datatree
110+ @define (slots = False )
111+ class GwfDis :
112+ nlay : int = field (default = 1 , metadata = {"block" : "dimensions" })
113+ ncol : int = field (default = 2 , metadata = {"block" : "dimensions" })
114+ nrow : int = field (default = 2 , metadata = {"block" : "dimensions" })
115+ delr : NDArray [np .floating ] = field (
116+ converter = attrs .Converter (
117+ _to_shaped_array , takes_self = True , takes_field = True
118+ ),
119+ default = 1.0 ,
120+ metadata = {"block" : "griddata" , "shape" : "(ncol,)" },
121+ )
122+ delc : NDArray [np .floating ] = field (
123+ converter = attrs .Converter (
124+ _to_shaped_array , takes_self = True , takes_field = True
125+ ),
126+ default = 1.0 ,
127+ metadata = {"block" : "griddata" , "shape" : "(nrow,)" },
128+ )
129+ top : NDArray [np .floating ] = field (
130+ converter = attrs .Converter (
131+ _to_shaped_array , takes_self = True , takes_field = True
132+ ),
133+ default = 1.0 ,
134+ metadata = {"block" : "griddata" , "shape" : "(ncol, nrow)" },
135+ )
136+ botm : NDArray [np .floating ] = field (
137+ converter = attrs .Converter (
138+ _to_shaped_array , takes_self = True , takes_field = True
139+ ),
140+ default = 0.0 ,
141+ metadata = {"block" : "griddata" , "shape" : "(ncol, nrow, nlay)" },
142+ )
143+ idomain : Optional [NDArray [np .integer ]] = field (
144+ converter = attrs .Converter (
145+ _to_shaped_array , takes_self = True , takes_field = True
146+ ),
147+ default = 1 ,
148+ metadata = {"block" : "griddata" , "shape" : "(ncol, nrow, nlay)" },
149+ )
150+ length_units : str = field (default = None , metadata = {"block" : "options" })
151+ nogrb : bool = field (default = False , metadata = {"block" : "options" })
152+ xorigin : float = field (default = None , metadata = {"block" : "options" })
153+ yorigin : float = field (default = None , metadata = {"block" : "options" })
154+ angrot : float = field (default = None , metadata = {"block" : "options" })
155+ export_array_netcdf : bool = field (
156+ default = False , metadata = {"block" : "options" }
157+ )
158+ nodes : int = field (init = False )
159+ data : Dataset = field (init = False )
160+ model : Optional [Any ] = field (default = None )
161+
162+ def __attrs_post_init__ (self ):
163+ self .nodes = self .nlay * self .ncol * self .nrow
164+
165+
166+ @datatree
167+ @define (slots = False )
17168class GwfIc :
18- strt : NDArray [np .float64 ] = field (
19- metadata = {"block" : "packagedata" , "shape" : "(nodes)" }
169+ strt : NDArray [np .floating ] = field (
170+ converter = attrs .Converter (
171+ _to_shaped_array , takes_self = True , takes_field = True
172+ ),
173+ metadata = {"block" : "packagedata" , "shape" : "(nodes)" },
20174 )
21175 export_array_ascii : bool = field (
22176 default = False , metadata = {"block" : "options" }
@@ -25,13 +179,12 @@ class GwfIc:
25179 default = False ,
26180 metadata = {"block" : "options" },
27181 )
182+ data : Dataset = field (init = False )
183+ model : Optional [Any ] = field (default = None )
28184
29- def __attrs_post_init__ (self ):
30- # TODO: setup attributes for blocks?
31- self .data = DataTree (Dataset ({"strt" : self .strt }), name = "ic" )
32185
33-
34- @define
186+ @ datatree
187+ @define ( slots = False )
35188class GwfOc :
36189 @define
37190 class Format :
@@ -40,96 +193,42 @@ class Format:
40193 digits : int
41194 format : Literal ["exponential" , "fixed" , "general" , "scientific" ]
42195
43- periods : List [List [tuple ]] = field (metadata = {"block" : "perioddata" })
44196 budget_file : Optional [Path ] = field (
45- default = None , metadata = {"block" : "options" }
197+ converter = _to_path , default = None , metadata = {"block" : "options" }
46198 )
47199 budget_csv_file : Optional [Path ] = field (
48- default = None , metadata = {"block" : "options" }
200+ converter = _to_path , default = None , metadata = {"block" : "options" }
49201 )
50202 head_file : Optional [Path ] = field (
51- default = None , metadata = {"block" : "options" }
203+ converter = _to_path , default = None , metadata = {"block" : "options" }
52204 )
53205 printhead : Optional [Format ] = field (
54206 default = None , metadata = {"block" : "options" }
55207 )
56-
57-
58- @define
59- class GwfDis :
60- nlay : int = field (metadata = {"block" : "dimensions" })
61- ncol : int = field (metadata = {"block" : "dimensions" })
62- nrow : int = field (metadata = {"block" : "dimensions" })
63- delr : NDArray [np .float64 ] = field (
64- metadata = {"block" : "griddata" , "shape" : "(ncol,)" }
65- )
66- delc : NDArray [np .float64 ] = field (
67- metadata = {"block" : "griddata" , "shape" : "(nrow,)" }
68- )
69- top : NDArray [np .float64 ] = field (
70- metadata = {"block" : "griddata" , "shape" : "(ncol, nrow)" }
71- )
72- botm : NDArray [np .float64 ] = field (
73- metadata = {"block" : "griddata" , "shape" : "(ncol, nrow, nlay)" }
74- )
75- idomain : NDArray [np .float64 ] = field (
76- metadata = {"block" : "griddata" , "shape" : "(ncol, nrow, nlay)" }
77- )
78- length_units : str = field (default = None , metadata = {"block" : "options" })
79- nogrb : bool = field (default = False , metadata = {"block" : "options" })
80- xorigin : float = field (default = None , metadata = {"block" : "options" })
81- yorigin : float = field (default = None , metadata = {"block" : "options" })
82- angrot : float = field (default = None , metadata = {"block" : "options" })
83- export_array_netcdf : bool = field (
84- default = False , metadata = {"block" : "options" }
208+ perioddata : Optional [list [list [tuple ]]] = field (
209+ default = None , metadata = {"block" : "perioddata" }
85210 )
86-
87- def __attrs_post_init__ (self ):
88- self .data = DataTree (
89- Dataset (
90- {
91- "nlay" : self .nlay ,
92- "ncol" : self .ncol ,
93- "nrow" : self .nrow ,
94- "delr" : self .delr ,
95- "delc" : self .delc ,
96- "top" : self .top ,
97- "botm" : self .botm ,
98- "idomain" : self .idomain ,
99- }
100- ),
101- name = "dis" ,
102- )
103- # TODO: check for parent and update dimensions
104- # then try to realign any existing packages?
211+ data : Dataset = field (init = False )
212+ model : Optional [Any ] = field (default = None )
105213
106214
107- @define
215+ @datatree
216+ @define (slots = False )
108217class Gwf :
109- dis : GwfDis = field ()
110- ic : GwfIc = field ()
111-
112- def __attrs_post_init__ (self ):
113- self .data = DataTree .from_dict (
114- {"/dis" : self .dis , "/ic" : self .ic }, name = "gwf"
115- )
116- self .grid = StructuredGrid (** asdict (self .dis ))
117-
118- @ic .validator
119- def _check_dims (self , attribute , value ):
120- assert value .strt .shape == (
121- self .dis .nlay * self .dis .nrow * self .dis .ncol
122- )
218+ dis : Optional [GwfDis ] = field (default = None )
219+ ic : Optional [GwfIc ] = field (default = None )
220+ oc : Optional [GwfOc ] = field (default = None )
221+ data : DataTree = field (init = False )
123222
124223
125224# We can define a package with some data.
126225
127226
128227oc = GwfOc (
129228 budget_file = "some/file/path.cbc" ,
130- periods = [[("print" , "budget" , "steps" , 1 , 3 , 5 )]],
229+ perioddata = [[("print" , "budget" , "steps" , 1 , 3 , 5 )]],
131230)
132- assert isinstance (oc .budget_file , str ) # TODO path
231+ assert isinstance (oc .budget_file , Path )
133232
134233
135234# We now set up a `cattrs` converter to convert an unstructured
@@ -142,7 +241,7 @@ def _check_dims(self, attribute, value):
142241# as would be returned by a separate IO layer in the future.
143242# (Either hand-written or using e.g. lark.)
144243
145- gwfoc = converter .structure (
244+ oc = converter .structure (
146245 {
147246 "budget_file" : "some/file/path.cbc" ,
148247 "head_file" : "some/file/path.hds" ,
@@ -152,7 +251,7 @@ def _check_dims(self, attribute, value):
152251 "digits" : 8 ,
153252 "format" : "scientific" ,
154253 },
155- "periods " : [
254+ "perioddata " : [
156255 [
157256 ("print" , "budget" , "steps" , 1 , 3 , 5 ),
158257 ("save" , "head" , "frequency" , 2 ),
@@ -161,9 +260,15 @@ def _check_dims(self, attribute, value):
161260 },
162261 GwfOc ,
163262)
164- assert gwfoc .budget_file == Path ("some/file/path.cbc" )
165- assert gwfoc .printhead .width == 10
166- assert gwfoc .printhead .format == "scientific"
167- period = gwfoc . periods [0 ]
263+ assert oc .budget_file == Path ("some/file/path.cbc" )
264+ assert oc .printhead .width == 10
265+ assert oc .printhead .format == "scientific"
266+ period = oc . perioddata [0 ]
168267assert len (period ) == 2
169268assert period [0 ] == ("print" , "budget" , "steps" , 1 , 3 , 5 )
269+
270+
271+ # Creating a model by constructor.
272+
273+
274+ gwf = Gwf (dis = GwfDis (), ic = GwfIc (strt = 1 ), oc = oc )
0 commit comments