11from collections .abc import Mapping
2- from itertools import chain
32from pathlib import Path
43from typing import Annotated , Any , Optional , get_origin
54
@@ -97,7 +96,7 @@ def resolve_array(
9796 tree : DataTree = None ,
9897 strict : bool = False ,
9998 ** kwargs ,
100- ) -> Optional [NDArray ]:
99+ ) -> tuple [ Optional [NDArray ], Optional [ dict [ str , NDArray ]] ]:
101100 """
102101 Resolve an array-like value to the given variable's expected shape.
103102 If the value is a collection, check if the shape matches. If scalar,
@@ -118,15 +117,15 @@ def resolve_array(
118117 f"Component class '{ type (self ).__name__ } ' array "
119118 f"variable '{ attr .name } ' could not be resolved "
120119 )
121- return None
120+ return None , None
122121 dims = attr .metadata .get ("dims" , None )
123122 if not dims :
124123 if strict :
125124 raise ValueError (
126125 f"Component class '{ type (self ).__name__ } ' array "
127126 f"variable '{ attr .name } ' needs 'dims' metadata"
128127 )
129- return None
128+ return None , None
130129 shape = [find (tree or DataTree (), key = dim , default = dim ) for dim in dims ]
131130 shape = tuple (
132131 [
@@ -142,8 +141,10 @@ def resolve_array(
142141 f"variable '{ attr .name } ' failed dim resolution: "
143142 f"{ ', ' .join (unresolved )} "
144143 )
145- return None
146- return reshape_array (value , shape )
144+ return None , None
145+ array = reshape_array (value , shape )
146+ coords = {dim : np .arange (size ) for dim , size in zip (dims , shape )}
147+ return array , coords
147148
148149
149150def bind_tree (
@@ -160,10 +161,11 @@ def bind_tree(
160161 data tree, as well as to any non-`attrs` attributes whose
161162 name matches a child's name.
162163
163- TODO: this is massively duplicative, since each component
164- has a subtree of its own, next to the one its parent owns
165- and in which its tree appears. need to have a single tree
166- at the root, then each component's data is a view into it.
164+ TODO: discover dimensions from self, parent and children.
165+ If the parent defines a dimension, it should be used for
166+ self and children. If a dimension is found in self or in
167+ a child which has scope broader than self, send it up to
168+ the parent.
167169 """
168170
169171 cls = type (self )
@@ -172,30 +174,27 @@ def bind_tree(
172174
173175 # bind parent
174176 if parent :
175- # try binding first by name
177+ # bind to parent attrs whose name
178+ # matches this component's name
176179 parent_spec = fields_dict (type (parent ))
177180 parent_var = parent_spec .get (name , None )
178181 if parent_var :
179182 assert parent_var .metadata .get ("bind" , False )
180183 setattr (parent , name , self )
181- # TODO bind multipackages by type
182- # parent_bindings = {
183- # n: v
184- # for n, v in parent_spec.items()
185- # if v.metadata.get("bind", False)
186- # }
187- # print(parent_bindings)
184+
185+ # bind parent data tree
188186 if name in parent .data :
189187 parent .data .update ({name : self .data })
190188 else :
191189 parent .data = parent .data .assign ({name : self .data })
192190 self .data = parent .data [self .data .name ]
193191
194- # bind grandparent
192+ # bind grandparent recursively
195193 grandparent = getattr (parent , "parent" , None )
196194 if grandparent is not None :
197195 bind_tree (parent , parent = grandparent )
198196
197+ # update parent reference
199198 self .parent = parent
200199
201200 # bind children
@@ -234,11 +233,12 @@ def init_tree(
234233 cls = type (self )
235234 spec = fields_dict (cls )
236235 dimensions = set ()
236+ coordinates = {}
237+ components = {}
237238 array_vars = {}
238239 scalar_vars = {}
239240 array_vals = {}
240241 scalar_vals = {}
241- components = {}
242242
243243 for var in spec .values ():
244244 bind = var .metadata .get ("bind" , False )
@@ -264,21 +264,23 @@ def _yield_scalars(spec, vals):
264264 def _yield_arrays (spec , vals ):
265265 for var in spec .values ():
266266 dims = var .metadata ["dims" ]
267- val = resolve_array (
267+ val , coords = resolve_array (
268268 self ,
269269 var ,
270270 value = vals .pop (var .name , var .default ),
271271 tree = parent .data .root if parent else None ,
272272 ** {** scalar_vals , ** kwargs },
273273 )
274274 if val is not None :
275+ coordinates .update (coords )
275276 yield (var .name , (dims , val ))
276277
277278 array_vals = dict (list (_yield_arrays (spec = array_vars , vals = self .__dict__ )))
278279
279280 self .data = DataTree (
280281 Dataset (
281282 data_vars = array_vals ,
283+ coords = coordinates ,
282284 attrs = {
283285 n : v for n , v in scalar_vals .items () if n not in dimensions
284286 },
@@ -332,7 +334,7 @@ def setattribute(self: _Component, attr: Attribute, value: Any):
332334 return value
333335 if get_origin (attr .type ) in [list , np .ndarray ]:
334336 shape = attr .metadata ["dims" ]
335- value = resolve_array (self , attr , value )
337+ value , _ = resolve_array (self , attr , value )
336338 value = (shape , value )
337339 bind = attr .metadata .get ("bind" , False )
338340 if bind :
@@ -341,6 +343,30 @@ def setattribute(self: _Component, attr: Attribute, value: Any):
341343 self .data .update ({attr .name : value })
342344
343345
346+ def pop_dims (** kwargs ):
347+ """
348+ Use dims from `Grid` and/or `ModelTime` instances
349+ passed to `grid` and `time` keyword arguments, if
350+ available.
351+ """
352+ dims = {}
353+ grid : Grid = kwargs .pop ("grid" , None )
354+ time : ModelTime = kwargs .pop ("time" , None )
355+ grid_dims = ["nlay" , "nrow" , "ncol" , "nnodes" ]
356+ time_dims = ["nper" , "nstp" ]
357+ if grid :
358+ for dim in grid_dims :
359+ dims [dim ] = getattr (grid , dim )
360+ if time :
361+ for dim in time_dims :
362+ dims [dim ] = getattr (time , dim )
363+ for dim in grid_dims + time_dims :
364+ v = kwargs .pop (dim , None )
365+ if v is not None :
366+ dims [dim ] = v
367+ return kwargs , dims
368+
369+
344370def component (maybe_cls : Optional [type [_IsAttrs ]] = None ) -> type [_Component ]:
345371 """
346372 Attach a data tree to an `attrs` class instance, and use
@@ -362,28 +388,11 @@ def init(self, *args, **kwargs):
362388 children = kwargs .pop ("children" , None )
363389 parent = args [0 ] if args and any (args ) else None
364390
365- # use dims from grid and modeltime, if provided
366- dim_kwargs = {}
367- dims_used = set (
368- chain (* [var .metadata .get ("dims" , []) for var in spec .values ()])
369- )
370- grid : Grid = kwargs .pop ("grid" , None )
371- time : ModelTime = kwargs .pop ("time" , None )
372- if grid :
373- grid_dims = ["nlay" , "nrow" , "ncol" , "nnodes" ]
374- for dim in grid_dims :
375- if dim in dims_used :
376- dim_kwargs [dim ] = getattr (grid , dim )
377- if time :
378- time_dims = ["nper" , "ntstp" ]
379- for dim in time_dims :
380- if dim in dims_used :
381- dim_kwargs [dim ] = getattr (time , dim )
382-
383391 # run the original __init__, then set up the tree
392+ kwargs , dimensions = pop_dims (** kwargs )
384393 init_self (self , ** kwargs )
385394 init_tree (
386- self , name = name , parent = parent , children = children , ** dim_kwargs
395+ self , name = name , parent = parent , children = children , ** dimensions
387396 )
388397 bind_tree (self , parent = parent , children = children )
389398
@@ -397,3 +406,7 @@ def init(self, *args, **kwargs):
397406 return wrap
398407
399408 return wrap (maybe_cls )
409+
410+
411+ # TODO: add separate `component()` decorator like `attrs.field()`?
412+ # for now, "bind" metadata indicates subcomponent, not a variable.
0 commit comments