1- from collections .abc import Iterable , Mapping
1+ from collections .abc import Mapping
2+ from itertools import chain
23from pathlib import Path
34from typing import Annotated , Any , Optional , get_origin
45
@@ -215,6 +216,7 @@ def init_tree(
215216 name : Optional [str ] = None ,
216217 parent : Optional [_HasTree ] = None ,
217218 children : Optional [Mapping [str , _HasTree ]] = None ,
219+ ** kwargs ,
218220):
219221 """
220222 Initialize a data tree for a component class instance.
@@ -271,7 +273,7 @@ def _yield_arrays(spec, vals):
271273 var ,
272274 value = vals .pop (var .name , var .default ),
273275 tree = parent .data .root if parent else None ,
274- ** scalar_vals ,
276+ ** { ** scalar_vals , ** kwargs } ,
275277 )
276278 if val is not None :
277279 yield (var .name , (dims , val ))
@@ -347,11 +349,7 @@ def setattribute(self: _Component, attr: Attribute, value: Any):
347349 self .data .update ({attr .name : value })
348350
349351
350- def component (
351- maybe_cls : Optional [type [_IsAttrs ]] = None ,
352- * ,
353- align : Optional [Iterable [str ]] = None ,
354- ) -> type [_Component ]:
352+ def component (maybe_cls : Optional [type [_IsAttrs ]] = None ) -> type [_Component ]:
355353 """
356354 Attach a data tree to an `attrs` class instance, and use
357355 the data tree for attribute storage: intercept gets/sets
@@ -365,26 +363,37 @@ def component(
365363
366364 def wrap (cls ):
367365 init_self = cls .__init__
366+ spec = fields_dict (cls )
368367
369368 def init (self , * args , ** kwargs ):
370369 name = kwargs .pop ("name" , None )
371370 children = kwargs .pop ("children" , None )
372371 parent = args [0 ] if args and any (args ) else None
373372
374373 # resolve dims from grid and time discretizations
374+ # get dims from spec
375+ dim_kwargs = {}
376+ dims_used = set (
377+ chain (* [var .metadata .get ("dims" , []) for var in spec .values ()])
378+ )
375379 grid : Grid = kwargs .pop ("grid" , None )
376380 time : ModelTime = kwargs .pop ("time" , None )
377- diss = [dis for dis in [grid , time ] if dis ]
378- if align :
379- for dim in align :
380- for dis in diss :
381- attr = getattr (dis , dim , None )
382- if attr is not None :
383- kwargs [dim ] = attr
381+ if grid :
382+ grid_dims = ["nlay" , "nrow" , "ncol" , "nnodes" ]
383+ for dim in grid_dims :
384+ if dim in dims_used :
385+ dim_kwargs [dim ] = getattr (grid , dim )
386+ if time :
387+ time_dims = ["nper" , "ntstp" ]
388+ for dim in time_dims :
389+ if dim in dims_used :
390+ dim_kwargs [dim ] = getattr (time , dim )
384391
385392 # run the original __init__, then set up the tree
386393 init_self (self , ** kwargs )
387- init_tree (self , name = name , parent = parent , children = children )
394+ init_tree (
395+ self , name = name , parent = parent , children = children , ** dim_kwargs
396+ )
388397
389398 # override attribute access
390399 cls .__getattr__ = getattribute
0 commit comments