44
55from os import PathLike
66from pathlib import Path
7- from typing import Any , Literal , Optional , get_origin
7+ from typing import Literal , Optional , get_origin
88from warnings import warn
99
10- import attrs
1110import numpy as np
1211from attr import define , field , fields_dict
1312from cattr import Converter
@@ -35,23 +34,28 @@ def _try_resolve_dim(self, name) -> int | str:
3534 return name
3635
3736
38- def _to_array (value : ArrayLike ) -> Optional [NDArray ]:
37+ def _try_resolve_shape (self , field ) -> tuple [int | str ]:
38+ dim_names = _parse_dim_names (field .metadata ["shape" ])
39+ return tuple ([_try_resolve_dim (self , n ) for n in dim_names ])
40+
41+
42+ def _to_array (value : Optional [ArrayLike ]) -> Optional [NDArray ]:
3943 return None if value is None else np .array (value )
4044
4145
4246def _to_shaped_array (
43- value : ArrayLike | str | PathLike , self_ , field
47+ value : Optional [ ArrayLike | str | PathLike ] , self_ , field
4448) -> Optional [NDArray ]:
4549 if isinstance (value , (str , PathLike )):
46- # TODO
50+ # TODO handle external arrays
4751 pass
4852
4953 value = _to_array (value )
5054 if value is None :
5155 return None
52- dim_names = _parse_dim_names ( field . metadata [ "shape" ])
53- shape = tuple ([ _try_resolve_dim ( self_ , n ) for n in dim_names ] )
54- unresolved = [d for d in shape if not isinstance (d , int )]
56+
57+ shape = _try_resolve_shape ( self_ , field )
58+ unresolved = [dim for dim in shape if not isinstance (dim , int )]
5559 if any (unresolved ):
5660 warn (f"Failed to resolve dimension names: { ', ' .join (unresolved )} " )
5761 return value
@@ -69,20 +73,10 @@ def _to_path(value) -> Optional[Path]:
6973
7074
7175def datatree (cls ):
72- # TODO
73- # - determine whether data array, data set, or data tree DONE
74- # - shape check arrays (dynamic validator?)
75- # check for parent and update dimensions
76- # then try to realign existing packages?
77-
78- old_post_init = getattr (cls , "__attrs_post_init__" , None )
79-
80- def __attrs_post_init__ (self ):
81- print (f"Running datatree on { cls .__name__ } " )
82-
83- if old_post_init :
84- old_post_init (self )
76+ post_init_name = "__attrs_post_init__"
77+ post_init_prev = getattr (cls , post_init_name , None )
8578
79+ def _set_data_on_self (self , cls ):
8680 fields = fields_dict (cls )
8781 arrays = {}
8882 for n , f in fields .items ():
@@ -91,59 +85,71 @@ def __attrs_post_init__(self):
9185 value = getattr (self , n )
9286 if value is None :
9387 continue
94- arrays [n ] = (_parse_dim_names (f .metadata ["shape" ]), value )
88+ arrays [n ] = (
89+ _parse_dim_names (f .metadata ["shape" ]),
90+ _to_shaped_array (value , self , f ),
91+ )
9592 dataset = Dataset (arrays )
96- children = getattr (self , "children" , None )
97- if children :
98- self .data = DataTree (
99- dataset , name = cls .__name__ , children = [c .data for c in children ]
93+ self .data = (
94+ DataTree (dataset , name = cls .__name__ .lower ()[3 :])
95+ if issubclass (cls , Model )
96+ else dataset
97+ )
98+
99+ def _set_self_on_model (self , cls ):
100+ model = getattr (self , "model" , None )
101+ if model :
102+ self_name = cls .__name__ .lower ()[3 :]
103+ setattr (model , self_name , self )
104+ model .data = model .data .assign (
105+ {self_name : DataTree (self .data , name = self_name )}
100106 )
101- else :
102- self .data = dataset
103107
104- cls .__attrs_post_init__ = __attrs_post_init__
108+ def __attrs_post_init__ (self ):
109+ if post_init_prev :
110+ post_init_prev (self )
111+
112+ _set_data_on_self (self , cls )
113+ _set_self_on_model (self , cls )
105114
115+ # TODO: figure out why classes need to have a
116+ # __attrs_post_init__ method for this to work
117+ setattr (cls , post_init_name , __attrs_post_init__ )
106118 return cls
107119
108120
121+ class Model :
122+ pass
123+
124+
109125@datatree
110126@define (slots = False )
111127class GwfDis :
112128 nlay : int = field (default = 1 , metadata = {"block" : "dimensions" })
113129 ncol : int = field (default = 2 , metadata = {"block" : "dimensions" })
114130 nrow : int = field (default = 2 , metadata = {"block" : "dimensions" })
115131 delr : NDArray [np .floating ] = field (
116- converter = attrs .Converter (
117- _to_shaped_array , takes_self = True , takes_field = True
118- ),
132+ converter = _to_array ,
119133 default = 1.0 ,
120134 metadata = {"block" : "griddata" , "shape" : "(ncol,)" },
121135 )
122136 delc : NDArray [np .floating ] = field (
123- converter = attrs .Converter (
124- _to_shaped_array , takes_self = True , takes_field = True
125- ),
137+ converter = _to_array ,
126138 default = 1.0 ,
127139 metadata = {"block" : "griddata" , "shape" : "(nrow,)" },
128140 )
129141 top : NDArray [np .floating ] = field (
130- converter = attrs .Converter (
131- _to_shaped_array , takes_self = True , takes_field = True
132- ),
142+ converter = _to_array ,
133143 default = 1.0 ,
134144 metadata = {"block" : "griddata" , "shape" : "(ncol, nrow)" },
135145 )
136146 botm : NDArray [np .floating ] = field (
137- converter = attrs .Converter (
138- _to_shaped_array , takes_self = True , takes_field = True
139- ),
147+ converter = _to_array ,
140148 default = 0.0 ,
141149 metadata = {"block" : "griddata" , "shape" : "(ncol, nrow, nlay)" },
142150 )
143151 idomain : Optional [NDArray [np .integer ]] = field (
144- converter = attrs .Converter (
145- _to_shaped_array , takes_self = True , takes_field = True
146- ),
152+ converter = _to_array ,
147153 default = 1 ,
148154 metadata = {"block" : "griddata" , "shape" : "(ncol, nrow, nlay)" },
149155 )
@@ -156,8 +162,7 @@ class GwfDis:
156162 default = False , metadata = {"block" : "options" }
157163 )
158164 nodes : int = field (init = False )
159- data : Dataset = field (init = False )
160- model : Optional [Any ] = field (default = None )
165+ model : Optional [Model ] = field (default = None )
161166
162167 def __attrs_post_init__ (self ):
163168 self .nodes = self .nlay * self .ncol * self .nrow
@@ -167,9 +172,8 @@ def __attrs_post_init__(self):
167172@define (slots = False )
168173class GwfIc :
169174 strt : NDArray [np .floating ] = field (
170- converter = attrs .Converter (
171- _to_shaped_array , takes_self = True , takes_field = True
172- ),
175+ converter = _to_array ,
176+ default = 1.0 ,
173177 metadata = {"block" : "packagedata" , "shape" : "(nodes)" },
174178 )
175179 export_array_ascii : bool = field (
@@ -179,8 +183,11 @@ class GwfIc:
179183 default = False ,
180184 metadata = {"block" : "options" },
181185 )
182- data : Dataset = field (init = False )
183- model : Optional [Any ] = field (default = None )
186+ model : Optional [Model ] = field (default = None )
187+
188+ def __attrs_post_init__ (self ):
189+ # for some reason this is necessary..
190+ pass
184191
185192
186193@datatree
@@ -208,17 +215,23 @@ class Format:
208215 perioddata : Optional [list [list [tuple ]]] = field (
209216 default = None , metadata = {"block" : "perioddata" }
210217 )
211- data : Dataset = field (init = False )
212- model : Optional [Any ] = field (default = None )
218+ model : Optional [Model ] = field (default = None )
219+
220+ def __attrs_post_init__ (self ):
221+ # for some reason this is necessary..
222+ pass
213223
214224
215225@datatree
216226@define (slots = False )
217- class Gwf :
227+ class Gwf ( Model ) :
218228 dis : Optional [GwfDis ] = field (default = None )
219229 ic : Optional [GwfIc ] = field (default = None )
220230 oc : Optional [GwfOc ] = field (default = None )
221- data : DataTree = field (init = False )
231+
232+ def __attrs_post_init__ (self ):
233+ # for some reason this is necessary..
234+ pass
222235
223236
224237# We can define a package with some data.
@@ -268,7 +281,12 @@ class Gwf:
268281assert period [0 ] == ("print" , "budget" , "steps" , 1 , 3 , 5 )
269282
270283
271- # Creating a model by constructor.
284+ # Create a model.
285+
272286
287+ gwf = Gwf ()
288+ dis = GwfDis (model = gwf )
289+ ic = GwfIc (model = gwf , strt = 1 )
290+ oc .model = gwf
273291
274- gwf = Gwf ( dis = GwfDis (), ic = GwfIc ( strt = 1 ), oc = oc )
292+ # View the data tree.
0 commit comments