1818Scalar = bool | int | float | str | Path
1919"""A scalar value."""
2020
21- _HasAttrs = Annotated [object , Is [lambda obj : attrs .has (type (obj ))]]
21+ _IsAttrs = Annotated [object , Is [lambda obj : attrs .has (type (obj ))]]
2222"""Runtime-applied type hint for `attrs` based class instances."""
2323
24- _HasData = Annotated [object , IsAttr ["data" , IsInstance [DataTree ]]]
24+ _HasTree = Annotated [object , IsAttr ["data" , IsInstance [DataTree ]]]
2525"""Runtime-applied type hint for objects with a `DataTree` in `.data`."""
2626
2727_Component = Annotated [
@@ -92,7 +92,7 @@ def _find_recursive(tree, key):
9292
9393
9494def resolve_array (
95- self : _HasAttrs ,
95+ self : _IsAttrs ,
9696 attr : Attribute ,
9797 value : ArrayLike ,
9898 tree : DataTree = None ,
@@ -120,7 +120,7 @@ def resolve_array(
120120 shape = [find (tree or DataTree (), key = dim , default = dim ) for dim in dims ]
121121 shape = tuple (
122122 [
123- (dim if isinstance (dim , int ) else kwargs .get (dim , dim ))
123+ (dim if isinstance (dim , int ) else kwargs .pop (dim , dim ))
124124 for dim in shape
125125 ]
126126 )
@@ -140,30 +140,58 @@ def resolve_array(
140140 return value
141141
142142
143- def bind_tree (self : _HasData , parent : _HasData ):
143+ def bind_tree (
144+ self : _Component ,
145+ parent : _Component = None ,
146+ children : Optional [Mapping [str , _Component ]] = None ,
147+ ):
144148 """
145- Bind a child component to a parent, linking their trees.
146- If the parent isn't the root, rebind it to recursively
147- upwards to the root.
149+ Bind a given component to a parent component, linking the
150+ two components and their data trees. If the parent is not
151+ the tree's root, rebind it to recursively up to the root.
152+
153+ Also attach any child components to the given component's
154+ data tree, as well as to any non-`attrs` attributes whose
155+ name matches a child's name.
148156
149157 TODO: this is massively duplicative, since each component
150158 has a subtree of its own, next to the one its parent owns
151159 and in which its tree appears. need to have a single tree
152160 at the root, then each component's data is a view into it.
153161 """
154- parent .data = parent .data .assign ({self .data .name : self .data })
155- self .data = parent .data [self .data .name ]
156- grandparent = getattr (parent , "parent" , None )
157- if grandparent is not None :
158- bind_tree (parent , grandparent )
159- self .parent = parent
162+
163+ cls = type (self )
164+
165+ if parent :
166+ parent_spec = fields_dict (type (parent ))
167+ if self .data .name in parent_spec :
168+ setattr (parent , self .data .name , self )
169+
170+ # TODO
171+ # parent_bindings = {
172+ # k: v
173+ # for k, v in parent_spec.items()
174+ # if v.metadata.get("bind", False)
175+ # }
176+
177+ parent .data = parent .data .assign ({self .data .name : self .data })
178+ self .data = parent .data [self .data .name ]
179+ grandparent = getattr (parent , "parent" , None )
180+ if grandparent is not None :
181+ bind_tree (parent , grandparent )
182+ self .parent = parent
183+ self .children = children
184+ spec = fields_dict (type (self ))
185+ for n , c in (children or {}).items ():
186+ if n in spec :
187+ setattr (self , n , c )
160188
161189
162190def init_tree (
163- self : _HasAttrs ,
191+ self : _IsAttrs ,
164192 name : Optional [str ] = None ,
165- parent : Optional [_HasData ] = None ,
166- children : Optional [Mapping [str , _HasData ]] = None ,
193+ parent : Optional [_HasTree ] = None ,
194+ children : Optional [Mapping [str , _HasTree ]] = None ,
167195):
168196 """
169197 Initialize a data tree for a component class instance.
@@ -180,40 +208,43 @@ class cannot use slots for this to work.
180208 spec = fields_dict (cls )
181209 data = Dataset ()
182210 dims = set ()
211+ arrays = {}
212+ scalars = {}
213+ children = children or {}
183214
184- # set arrays, then scalars. filter array dims out
185- # on the first pass thru, while we set up arrays,
215+ # set scalars and arrays. filter array dims out
186216 # so they're not attached as both vars and dims.
217+ # also filter out subcomponents, just want vars.
187218 for attr in spec .values ():
219+ bind = attr .metadata .get ("bind" , False )
220+ if bind :
221+ continue
188222 dims_ = attr .metadata .get ("dims" , None )
189223 if dims_ is None :
224+ scalars [attr .name ] = attr
190225 continue
191226 dims .update (dims_ )
227+ arrays [attr .name ] = attr
228+ scalars = {k : self .__dict__ .pop (k , v .default ) for k , v in scalars .items ()}
229+ for attr in arrays .values ():
230+ dims_ = attr .metadata ["dims" ]
192231 value = resolve_array (
193232 self ,
194233 attr ,
195- value = self .__dict__ .pop (attr .name ),
234+ value = self .__dict__ .pop (attr .name , attr . default ),
196235 tree = parent .data .root if parent else None ,
197- ** self . __dict__ ,
236+ ** scalars ,
198237 )
199238 data [attr .name ] = (dims_ , value )
200- for attr in spec .values ():
201- if attr .name in data or attr .name in dims :
202- continue
203- data [attr .name ] = self .__dict__ .pop (attr .name , attr .default )
239+ for k , v in scalars .items ():
240+ data .attrs [k ] = v
204241
205- # create tree
206242 self .data = DataTree (
207243 data ,
208244 name = name or cls .__name__ .lower (),
209- children = {
210- n : c .data for n , c in (children or {}).items () if c is not None
211- },
245+ children = {n : c .data for n , c in children .items ()},
212246 )
213-
214- # bind tree
215- if parent is not None :
216- bind_tree (self , parent )
247+ bind_tree (self , parent = parent , children = children )
217248
218249
219250def getattribute (self : Any , name : str ) -> Any :
@@ -230,8 +261,11 @@ def getattribute(self: Any, name: str) -> Any:
230261 """
231262 cls = type (self )
232263 spec = fields_dict (cls )
264+ if name == "data" :
265+ raise AttributeError
233266 tree = self .data
234- if name in spec :
267+ var = spec .get (name , None )
268+ if var :
235269 value = get (tree , name , None )
236270 if value is not None :
237271 return value
@@ -263,7 +297,7 @@ def setattribute(self: _Component, attr: Attribute, value: Any):
263297 # TODO run validation?
264298
265299
266- def component (cls : type [_HasAttrs ]) -> type [_Component ]:
300+ def component (cls : type [_IsAttrs ]) -> type [_Component ]:
267301 """
268302 Attach a data tree to an `attrs` class instance, and use
269303 the data tree for attribute storage: intercept gets/sets
@@ -277,13 +311,13 @@ def component(cls: type[_HasAttrs]) -> type[_Component]:
277311
278312 old_init = cls .__init__
279313
280- def init (self , * args , ** kwargs ):
314+ def _init (self , * args , ** kwargs ):
281315 name = kwargs .pop ("name" , None )
282316 parent = args [0 ] if args and any (args ) else None
283317 children = kwargs .pop ("children" , None )
284318 old_init (self , ** kwargs )
285319 init_tree (self , name = name , parent = parent , children = children )
286320 cls .__getattr__ = getattribute
287321
288- cls .__init__ = init
322+ cls .__init__ = _init
289323 return cls
0 commit comments