Skip to content

Commit 4f60128

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Fix dataclass array flat repr
PiperOrigin-RevId: 595656508
1 parent 0bd50d8 commit 4f60128

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2424
## [Unreleased]
2525

2626
* Fix `etree.spec_like`, `jax.ShapeDtypeStruct`,... support
27+
* Changed `jax.tree_util` keep paths (for better flatten repr)
2728

2829
## [1.5.1] - 2023-08-30
2930

dataclass_array/array_dataclass.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -271,15 +271,15 @@ def __post_init__(self) -> None:
271271
# Register the tree_map here instead of `__init_subclass__` as `jax` may
272272
# not have been imported yet during import.
273273
if enp.lazy.has_jax and not cls._dca_jax_tree_registered: # pylint: disable=protected-access
274-
enp.lazy.jax.tree_util.register_pytree_node_class(cls)
274+
enp.lazy.jax.tree_util.register_pytree_with_keys_class(cls)
275275
cls._dca_jax_tree_registered = True # pylint: disable=protected-access
276276

277277
if enp.lazy.has_torch and not cls._dca_torch_tree_registered: # pylint: disable=protected-access
278278
# Note: Torch is updating it's tree API to make it public and use `optree`
279279
# as backend: https://github.com/pytorch/pytorch/issues/65761
280280
enp.lazy.torch.utils._pytree._register_pytree_node( # pylint: disable=protected-access
281281
cls,
282-
flatten_fn=lambda a: a.tree_flatten(),
282+
flatten_fn=lambda a: a._tree_flatten(), # pylint: disable=protected-access
283283
unflatten_fn=lambda vals, ctx: cls.tree_unflatten(ctx, vals),
284284
)
285285
cls._dca_torch_tree_registered = True # pylint: disable=protected-access
@@ -295,7 +295,7 @@ def __post_init__(self) -> None:
295295

296296
# Validate the batch shape is consistent
297297
# However, we need to be careful that `_ArrayField` never uses
298-
# `@epy.cached_property`
298+
# `@functools.cached_property`
299299
shape = self._broadcast_shape_inplace()
300300

301301
# TODO(epot): When to validate (`field.validate()`)
@@ -581,7 +581,7 @@ def cuda(self: _DcT, *args, **kwargs) -> _DcT:
581581

582582
# ====== Internal ======
583583

584-
@epy.cached_property
584+
@functools.cached_property
585585
def _all_fields_empty(self) -> bool:
586586
"""Returns True if the `dataclass_array` is invalid."""
587587
if not self._array_fields: # All fields are `None` / `object`
@@ -601,9 +601,10 @@ def _all_fields_empty(self) -> bool:
601601
return True
602602
return False
603603

604-
@epy.cached_property
604+
@functools.cached_property
605605
def _all_array_fields(self) -> dict[str, _ArrayField]:
606606
"""All array fields, including `None` values."""
607+
assert self._dca_fields_metadata is not None
607608
return { # pylint: disable=g-complex-comprehension
608609
name: _ArrayField(
609610
name=name,
@@ -613,7 +614,7 @@ def _all_array_fields(self) -> dict[str, _ArrayField]:
613614
for name, field_metadata in self._dca_fields_metadata.items() # pylint: disable=protected-access
614615
}
615616

616-
@epy.cached_property
617+
@functools.cached_property
617618
def _array_fields(self) -> list[_ArrayField]:
618619
"""All active array fields (non-None), including static ones."""
619620
# Filter `None` values
@@ -767,10 +768,19 @@ def _apply_field_dn(f: _ArrayField):
767768
else:
768769
return self
769770

770-
def tree_flatten(self) -> tuple[tuple[DcOrArray, ...], _TreeMetadata]:
771+
def tree_flatten_with_keys(
772+
self,
773+
) -> tuple[
774+
tuple[tuple[enp.lazy.jax.tree_utils.BuiltInKeyEntry, DcOrArray], ...],
775+
_TreeMetadata,
776+
]:
771777
"""`jax.tree_utils` support."""
778+
tree_util = enp.lazy.jax.tree_util
772779
# We flatten all values (and not just the non-None ones)
773-
array_field_values = tuple(f.value for f in self._all_array_fields.values())
780+
array_field_values = tuple(
781+
(tree_util.GetAttrKey(k), f.value)
782+
for k, f in self._all_array_fields.items()
783+
)
774784
metadata = _TreeMetadata(
775785
array_field_names=list(self._all_array_fields.keys()),
776786
non_array_field_kwargs={
@@ -817,8 +827,13 @@ def tree_unflatten(
817827
self._setattr(k, v) # pylint: disable=protected-access
818828
return self
819829

830+
def _tree_flatten(self) -> tuple[tuple[DcOrArray, ...], _TreeMetadata]:
831+
components, metadata = self.tree_flatten_with_keys()
832+
components = tuple(v for _, v in components)
833+
return components, metadata
834+
820835
def __tf_flatten__(self) -> tuple[_TreeMetadata, tuple[DcOrArray, ...]]:
821-
components, metadata = self.tree_flatten()
836+
components, metadata = self._tree_flatten()
822837
return metadata, components
823838

824839
@classmethod
@@ -1070,7 +1085,7 @@ def full_shape(self) -> DcOrArrayT:
10701085
# empty shapes to True `bool(shape) == True` when `shape=()`
10711086
return tuple(self.value.shape)
10721087

1073-
@epy.cached_property
1088+
@functools.cached_property
10741089
def inner_shape(self) -> Shape:
10751090
"""Returns the the static shape resolved for the current value."""
10761091
if not self.inner_shape_non_static:

0 commit comments

Comments
 (0)