@@ -20,7 +20,7 @@ def _to_path(value: Any) -> Optional[Path]:
2020 return Path (value ) if value else None
2121
2222
23- def _parse_shape (shape : str ) -> tuple [str ]:
23+ def _parse_dim_names (shape : str ) -> tuple [str ]:
2424 return tuple (
2525 [
2626 dim .strip ()
@@ -33,29 +33,36 @@ def _parse_shape(shape: str) -> tuple[str]:
3333 )
3434
3535
36- def _try_resolve_dim (data : DataTree , name : str ) -> int | str :
36+ def _try_resolve_dim (data : Optional [ DataTree ] , name : str ) -> int | str :
3737 name = name .strip ()
38+ if data is None :
39+ return name
3840 value = data .get (name , None )
3941 if value is not None :
4042 return value .item ()
4143 root = data .root
4244 paths = [
43- "/tdis" ,
44- "/gwf/dis" ,
45+ "tdis" ,
46+ "dis" ,
47+ "gwf/dis" ,
4548 ]
4649 for path in paths :
47- value = root .get (f"{ path } /{ name } " , None )
48- if value is not None :
49- return value .item ()
50- print (f"Failed to resolve dim '{ name } ' for '{ data .name } '" )
50+ try :
51+ key = f"{ path } /{ name } "
52+ return root [key ].item ()
53+ except :
54+ try :
55+ return root [path ].dims [name ]
56+ except :
57+ pass
5158 return name
5259
5360
5461def _try_resolve_shape (data : DataTree , attr : Attribute ) -> tuple [int | str ]:
5562 shape = attr .metadata .get ("shape" , None )
5663 if shape is None :
5764 raise ValueError (f"Array { attr .name } missing shape metadata" )
58- shape = [_try_resolve_dim (data , dim ) for dim in _parse_shape (shape )]
65+ shape = [_try_resolve_dim (data , dim ) for dim in _parse_dim_names (shape )]
5966 return shape
6067
6168
@@ -85,19 +92,18 @@ def _resolve_array(
8592 return _reshape_array (value , shape )
8693
8794
88- def _bind_tree (data : DataTree ):
89- if data .is_root :
90- return
91- data . parent = data . parent . assign ({ data . name : data } )
92- if not data . parent . is_root :
93- _bind_tree (data . parent )
95+ def _bind_tree (self , parent ):
96+ parent . data = parent . data .assign ({ self . data . name : self . data })
97+ self . data = parent . data [ self . data . name ]
98+ grandparent = getattr ( parent , "parent" , None )
99+ if grandparent is not None :
100+ _bind_tree (parent , grandparent )
94101
95102
96- def _init_tree (self , ** kwargs ):
103+ def _init_tree (self , parent = None , ** kwargs ):
97104 cls = type (self )
98105 cls_name = cls .__name__ .lower ()
99106 spec = fields_dict (cls )
100- parent = kwargs .get ("parent" , None )
101107 data = Dataset ()
102108 dims = set ()
103109
@@ -106,13 +112,15 @@ def _init_tree(self, **kwargs):
106112 value = kwargs .get (name , attr .default )
107113 shape = attr .metadata .get ("shape" , None )
108114 if shape is not None :
109- dim_names = [
110- _try_resolve_dim (parent , dim ) for dim in _parse_shape (shape )
115+ dim_names = _parse_dim_names (shape )
116+ shape = [
117+ _try_resolve_dim (parent .data .root if parent else None , dim )
118+ for dim in dim_names
111119 ]
112120 shape = tuple (
113121 [
114122 (dim if isinstance (dim , int ) else kwargs .get (dim , dim ))
115- for dim in dim_names
123+ for dim in shape
116124 ]
117125 )
118126 unresolved = [dim for dim in shape if not isinstance (dim , int )]
@@ -139,6 +147,9 @@ def _init_tree(self, **kwargs):
139147 data [name ] = value
140148
141149 self .data = DataTree (data , name = cls_name )
150+ if parent is not None :
151+ self .parent = parent
152+ _bind_tree (self , parent )
142153
143154
144155def _setattr (self , attr : Attribute , value : Any ):
@@ -150,7 +161,7 @@ def _setattr(self, attr: Attribute, value: Any):
150161 return
151162 self .data [attr .name ] = (
152163 (
153- _parse_shape (attr .metadata ["shape" ]),
164+ _parse_dim_names (attr .metadata ["shape" ]),
154165 _resolve_array (self , attr , value ),
155166 )
156167 if get_origin (attr .type ) in [list , np .ndarray ]
@@ -161,18 +172,17 @@ def _setattr(self, attr: Attribute, value: Any):
161172
162173def component (cls ):
163174 spec = fields_dict (cls )
164- init = cls .__init__
165-
166- def _init (self , * args , ** kwargs ):
167- init (self , * args , ** kwargs )
168- _bind_tree (self .data )
169175
170176 def _get (self , name ):
171177 if name in spec :
172- return self .data [name ]
178+ value = self .data .get (name , None )
179+ if value is not None :
180+ return value
181+ value = self .data .dims .get (name , None )
182+ if value is not None :
183+ return value
173184 return super (cls , self ).__getattribute__ (name )
174185
175- cls .__init__ = _init
176186 cls .__getattribute__ = _get
177187 return cls
178188
@@ -264,7 +274,7 @@ def __init__(
264274 ):
265275 _init_tree (
266276 self ,
267- parent = model . data ,
277+ parent = model ,
268278 length_units = length_units ,
269279 nogrb = nogrb ,
270280 xorigin = xorigin ,
@@ -308,7 +318,7 @@ def __init__(
308318 ):
309319 _init_tree (
310320 self ,
311- parent = model . data ,
321+ parent = model ,
312322 strt = strt ,
313323 export_array_ascii = export_array_ascii ,
314324 export_array_netcdf = export_array_netcdf ,
@@ -369,7 +379,7 @@ def __init__(
369379 ):
370380 _init_tree (
371381 self ,
372- parent = model . data ,
382+ parent = model ,
373383 budget_file = budget_file ,
374384 budget_csv_file = budget_csv_file ,
375385 head_file = head_file ,
@@ -437,7 +447,7 @@ def __init__(
437447 ):
438448 _init_tree (
439449 self ,
440- parent = model . data ,
450+ parent = model ,
441451 icelltype = icelltype ,
442452 k = k ,
443453 k22 = k22 ,
@@ -456,7 +466,7 @@ def __init__(
456466 self ,
457467 sim = None ,
458468 ):
459- _init_tree (self , parent = sim . data )
469+ _init_tree (self , parent = sim )
460470
461471
462472@component
@@ -490,7 +500,7 @@ def __init__(
490500 ):
491501 _init_tree (
492502 self ,
493- parent = sim . data ,
503+ parent = sim ,
494504 nper = nper ,
495505 perioddata = perioddata ,
496506 time_units = time_units ,
0 commit comments