Skip to content

Commit 84685ab

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Fix tf.data compatibility for DataclassArray
PiperOrigin-RevId: 549577729
1 parent 525305f commit 84685ab

File tree

3 files changed

+4
-5
lines changed

3 files changed

+4
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26-
* Add `tf.nest` compatibility for `DataclassArray`.
26+
* Add `tf.nest`/`tf.data` compatibility for `DataclassArray`.
2727

2828
## [1.4.2] - 2023-07-10
2929

dataclass_array/array_dataclass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -765,10 +765,10 @@ def _apply_field_dn(f: _ArrayField):
765765
else:
766766
return self
767767

768-
def tree_flatten(self) -> tuple[list[DcOrArray], _TreeMetadata]:
768+
def tree_flatten(self) -> tuple[tuple[DcOrArray, ...], _TreeMetadata]:
769769
"""`jax.tree_utils` support."""
770770
# We flatten all values (and not just the non-None ones)
771-
array_field_values = [f.value for f in self._all_array_fields.values()]
771+
array_field_values = tuple(f.value for f in self._all_array_fields.values())
772772
metadata = _TreeMetadata(
773773
array_field_names=list(self._all_array_fields.keys()),
774774
non_array_field_kwargs={
@@ -815,7 +815,7 @@ def tree_unflatten(
815815
self._setattr(k, v) # pylint: disable=protected-access
816816
return self
817817

818-
def __tf_flatten__(self) -> tuple[_TreeMetadata, list[DcOrArray]]:
818+
def __tf_flatten__(self) -> tuple[_TreeMetadata, tuple[DcOrArray, ...]]:
819819
components, metadata = self.tree_flatten()
820820
return metadata, components
821821

dataclass_array/array_dataclass_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,6 @@ def test_tree_map(tree_map, dca_cls: DcaTest):
596596
dca_cls.assert_val(p, (1, 3), xnp=np)
597597

598598

599-
@pytest.mark.skip('tf.data fail currently') # TODO(epot): Restore
600599
def test_tf_data():
601600
ds = tf.data.Dataset.range(3)
602601
ds = ds.map(lambda x: Point(x=x, y=x))

0 commit comments

Comments
 (0)