@@ -271,15 +271,15 @@ def __post_init__(self) -> None:
271271 # Register the tree_map here instead of `__init_subclass__` as `jax` may
272272 # not have been imported yet during import.
273273 if enp .lazy .has_jax and not cls ._dca_jax_tree_registered : # pylint: disable=protected-access
274- enp .lazy .jax .tree_util .register_pytree_node_class (cls )
274+ enp .lazy .jax .tree_util .register_pytree_with_keys_class (cls )
275275 cls ._dca_jax_tree_registered = True # pylint: disable=protected-access
276276
277277 if enp .lazy .has_torch and not cls ._dca_torch_tree_registered : # pylint: disable=protected-access
278278 # Note: Torch is updating it's tree API to make it public and use `optree`
279279 # as backend: https://github.com/pytorch/pytorch/issues/65761
280280 enp .lazy .torch .utils ._pytree ._register_pytree_node ( # pylint: disable=protected-access
281281 cls ,
282- flatten_fn = lambda a : a .tree_flatten (),
282+ flatten_fn = lambda a : a ._tree_flatten (), # pylint: disable=protected-access
283283 unflatten_fn = lambda vals , ctx : cls .tree_unflatten (ctx , vals ),
284284 )
285285 cls ._dca_torch_tree_registered = True # pylint: disable=protected-access
@@ -295,7 +295,7 @@ def __post_init__(self) -> None:
295295
296296 # Validate the batch shape is consistent
297297 # However, we need to be careful that `_ArrayField` never uses
298- # `@epy .cached_property`
298+ # `@functools .cached_property`
299299 shape = self ._broadcast_shape_inplace ()
300300
301301 # TODO(epot): When to validate (`field.validate()`)
@@ -581,7 +581,7 @@ def cuda(self: _DcT, *args, **kwargs) -> _DcT:
581581
582582 # ====== Internal ======
583583
584- @epy .cached_property
584+ @functools .cached_property
585585 def _all_fields_empty (self ) -> bool :
586586 """Returns True if the `dataclass_array` is invalid."""
587587 if not self ._array_fields : # All fields are `None` / `object`
@@ -601,9 +601,10 @@ def _all_fields_empty(self) -> bool:
601601 return True
602602 return False
603603
604- @epy .cached_property
604+ @functools .cached_property
605605 def _all_array_fields (self ) -> dict [str , _ArrayField ]:
606606 """All array fields, including `None` values."""
607+ assert self ._dca_fields_metadata is not None
607608 return { # pylint: disable=g-complex-comprehension
608609 name : _ArrayField (
609610 name = name ,
@@ -613,7 +614,7 @@ def _all_array_fields(self) -> dict[str, _ArrayField]:
613614 for name , field_metadata in self ._dca_fields_metadata .items () # pylint: disable=protected-access
614615 }
615616
616- @epy .cached_property
617+ @functools .cached_property
617618 def _array_fields (self ) -> list [_ArrayField ]:
618619 """All active array fields (non-None), including static ones."""
619620 # Filter `None` values
@@ -767,10 +768,19 @@ def _apply_field_dn(f: _ArrayField):
767768 else :
768769 return self
769770
770- def tree_flatten (self ) -> tuple [tuple [DcOrArray , ...], _TreeMetadata ]:
771+ def tree_flatten_with_keys (
772+ self ,
773+ ) -> tuple [
774+ tuple [tuple [enp .lazy .jax .tree_utils .BuiltInKeyEntry , DcOrArray ], ...],
775+ _TreeMetadata ,
776+ ]:
771777 """`jax.tree_utils` support."""
778+ tree_util = enp .lazy .jax .tree_util
772779 # We flatten all values (and not just the non-None ones)
773- array_field_values = tuple (f .value for f in self ._all_array_fields .values ())
780+ array_field_values = tuple (
781+ (tree_util .GetAttrKey (k ), f .value )
782+ for k , f in self ._all_array_fields .items ()
783+ )
774784 metadata = _TreeMetadata (
775785 array_field_names = list (self ._all_array_fields .keys ()),
776786 non_array_field_kwargs = {
@@ -817,8 +827,13 @@ def tree_unflatten(
817827 self ._setattr (k , v ) # pylint: disable=protected-access
818828 return self
819829
830+ def _tree_flatten (self ) -> tuple [tuple [DcOrArray , ...], _TreeMetadata ]:
831+ components , metadata = self .tree_flatten_with_keys ()
832+ components = tuple (v for _ , v in components )
833+ return components , metadata
834+
820835 def __tf_flatten__ (self ) -> tuple [_TreeMetadata , tuple [DcOrArray , ...]]:
821- components , metadata = self .tree_flatten ()
836+ components , metadata = self ._tree_flatten ()
822837 return metadata , components
823838
824839 @classmethod
@@ -1070,7 +1085,7 @@ def full_shape(self) -> DcOrArrayT:
10701085 # empty shapes to True `bool(shape) == True` when `shape=()`
10711086 return tuple (self .value .shape )
10721087
1073- @epy .cached_property
1088+ @functools .cached_property
10741089 def inner_shape (self ) -> Shape :
10751090 """Returns the the static shape resolved for the current value."""
10761091 if not self .inner_shape_non_static :
0 commit comments