Skip to content

Commit 0d34c7e

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Add tf.nest support to dca
PiperOrigin-RevId: 549246873
1 parent f6ed3d7 commit 0d34c7e

File tree

5 files changed

+131
-30
lines changed

5 files changed

+131
-30
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26+
* Add `tf.nest` compatibility for `DataclassArray`.
27+
2628
## [1.4.2] - 2023-07-10
2729

2830
* Add `dca.concat` method in addition to `dca.stack`.

dataclass_array/array_dataclass.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,24 @@ def cuda(self: _DcT, *args, **kwargs) -> _DcT:
581581

582582
# ====== Internal ======
583583

584+
@epy.cached_property
585+
def _all_fields_empty(self) -> bool:
586+
"""Returns True if the `dataclass_array` is invalid."""
587+
if not self._array_fields: # All fields are `None` / `object`
588+
# No fields have been defined.
589+
# This can be the case internally by jax which apply some
590+
# `tree_map(lambda x: sentinel)`.
591+
return True
592+
593+
# `tf.nest` sometimes replace values by dummy `.` inside
594+
# `assert_same_structure`
595+
if enp.lazy.has_tf:
596+
from tensorflow.python.util import nest_util # pytype: disable=import-error # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
597+
598+
if any(f.value is nest_util._DOT for f in self._array_fields): # pylint: disable=protected-access,not-an-iterable
599+
return True
600+
return False
601+
584602
@epy.cached_property
585603
def _all_array_fields(self) -> dict[str, _ArrayField]:
586604
"""All array fields, including `None` values."""
@@ -603,7 +621,7 @@ def _array_fields(self) -> list[_ArrayField]:
603621

604622
def _cast_xnp_dtype_inplace(self) -> Optional[enp.NpModule]:
605623
"""Validate `xnp` are consistent and cast `np` -> `xnp` in-place."""
606-
if not self._array_fields: # All fields are `None` / `object`
624+
if self._all_fields_empty: # pylint: disable=using-constant-test
607625
return None
608626

609627
# Validate the dtype
@@ -627,12 +645,19 @@ def _get_xnp(f: _ArrayField) -> enp.NpModule:
627645

628646
def _cast_field(f: _ArrayField) -> None:
629647
try:
630-
new_value = np_utils.asarray(
631-
f.value,
632-
xnp=xnp,
633-
dtype=f.dtype,
634-
cast_dtype=self.__dca_params__.cast_dtype,
635-
)
648+
# Supports for TensorSpec (e.g. in `tf.function` signature)
649+
if enp.lazy.is_tf_xnp(xnp) and isinstance(
650+
f.value, enp.lazy.tf.TensorSpec
651+
):
652+
# TODO(epot): Actually check the dtype
653+
new_value = f.value
654+
else:
655+
new_value = np_utils.asarray(
656+
f.value,
657+
xnp=xnp,
658+
dtype=f.dtype,
659+
cast_dtype=self.__dca_params__.cast_dtype,
660+
)
636661
self._setattr(f.name, new_value)
637662
# After the field has been set, we validate the shape
638663
f.assert_shape()
@@ -648,9 +673,7 @@ def _cast_field(f: _ArrayField) -> None:
648673

649674
def _broadcast_shape_inplace(self) -> Optional[Shape]:
650675
"""Validate the shapes are consistent and broadcast values in-place."""
651-
if not self._array_fields: # No fields have been defined.
652-
# This can be the case internally by jax which apply some
653-
# `tree_map(lambda x: sentinel)`.
676+
if self._all_fields_empty: # pylint: disable=using-constant-test
654677
return None
655678

656679
# First collect all shapes and compute the final shape.
@@ -735,7 +758,7 @@ def _apply_field_dn(f: _ArrayField):
735758
else:
736759
return array_fn(f)
737760

738-
new_values = {f.name: _apply_field_dn(f) for f in self._array_fields} # pylint: disable=not-an-iterable
761+
new_values = {f.name: _apply_field_dn(f) for f in self._array_fields} # pylint: disable=not-an-iterable,protected-access
739762
# For performance, do not call replace to save the constructor call
740763
if not _inplace:
741764
return self.replace(**new_values)
@@ -792,6 +815,18 @@ def tree_unflatten(
792815
self._setattr(k, v) # pylint: disable=protected-access
793816
return self
794817

818+
def __tf_flatten__(self) -> tuple[_TreeMetadata, list[DcOrArray]]:
819+
components, metadata = self.tree_flatten()
820+
return metadata, components
821+
822+
@classmethod
823+
def __tf_unflatten__(
824+
cls: Type[_DcT],
825+
metadata: _TreeMetadata,
826+
components: list[DcOrArray],
827+
) -> _DcT:
828+
return cls.tree_unflatten(metadata, components)
829+
795830
def _setattr(self, name: str, value: Any) -> None:
796831
"""Like setattr, but support `frozen` dataclasses."""
797832
object.__setattr__(self, name, value)

dataclass_array/array_dataclass_test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import dataclass_array as dca
2222
from dataclass_array.typing import FloatArray, IntArray, f32, i32 # pylint: disable=g-multiple-import
23-
from dataclass_array.typing import Shape # pylint: disable=g-multiple-import
23+
from dataclass_array.typing import Shape # pylint: disable=g-multiple-import,g-importing-member
2424
from etils import enp
2525
import numpy as np
2626
import pytest
@@ -585,6 +585,7 @@ def test_infer_np(xnp: enp.NpModule):
585585
@pytest.mark.parametrize(
586586
'tree_map',
587587
[
588+
enp.lazy.tf.nest.map_structure,
588589
enp.lazy.jax.tree_map,
589590
enp.lazy.torch.utils._pytree.tree_map,
590591
],
@@ -595,6 +596,17 @@ def test_tree_map(tree_map, dca_cls: DcaTest):
595596
dca_cls.assert_val(p, (1, 3), xnp=np)
596597

597598

599+
@pytest.mark.skip('tf.data fail currently') # TODO(epot): Restore
600+
def test_tf_data():
601+
ds = tf.data.Dataset.range(3)
602+
ds = ds.map(lambda x: Point(x=x, y=x))
603+
604+
for ex in ds:
605+
assert isinstance(ex, Point)
606+
assert ex.x.shape == () # pylint: disable=g-explicit-bool-comparison
607+
assert ex.y.shape == () # pylint: disable=g-explicit-bool-comparison
608+
609+
598610
def test_torch_device():
599611
p = Nested.make(shape=(2,), xnp=enp.lazy.torch)
600612
p = p.cpu()

dataclass_array/utils/inspect_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ class Signature(Generic[_OutT]):
5050
5151
y = bound_args.call() # Call the function
5252
```
53-
5453
"""
5554

5655
fn: _Fn[_OutT]
@@ -202,7 +201,7 @@ def map_bound_arg(
202201
self,
203202
fn: Callable[[BoundArg[_ArgT]], _NewArgT],
204203
) -> BoundArgs[_NewArgT, _OutT]:
205-
"""Apply validation/modification to the arguments value."""
204+
"""Apply validation/modification to all `BoundArg`."""
206205

207206
def _fn(arg: BoundArg[_ArgT]) -> _NewArgT: # pytype: disable=invalid-annotation
208207
try:
@@ -213,9 +212,14 @@ def _fn(arg: BoundArg[_ArgT]) -> _NewArgT: # pytype: disable=invalid-annotation
213212
prefix=f'For arg {arg.pos} ({arg.name!r}):\n',
214213
)
215214

215+
return self.replace_args_values({arg.name: _fn(arg) for arg in self})
216+
217+
def replace_args_values(
218+
self, new_values: dict[str, _NewArgT]
219+
) -> BoundArgs[_NewArgT, _OutT]:
216220
bound_args = inspect.BoundArguments(
217221
signature=self.bound_args.signature,
218-
arguments={arg.name: _fn(arg) for arg in self}, # pytype: disable=wrong-arg-types
222+
arguments=new_values, # pytype: disable=wrong-arg-types
219223
)
220224
return self.replace(bound_args=bound_args) # pytype: disable=attribute-error
221225

dataclass_array/vectorization.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from dataclass_array import array_dataclass
2424
from dataclass_array import ops
25-
from dataclass_array.typing import DcOrArray, Shape # pylint: disable=g-multiple-import
25+
from dataclass_array.typing import DcOrArray, Shape # pylint: disable=g-multiple-import,g-importing-member
2626
from dataclass_array.utils import inspect_utils
2727
from dataclass_array.utils import np_utils
2828
from dataclass_array.utils import py_utils
@@ -307,21 +307,30 @@ def _vmap_method(
307307
xnp: enp.NpModule,
308308
) -> _Out:
309309
"""Vectorize self using the `xnp` backend. Assume `self` was flatten."""
310-
if xnp is enp.lazy.np:
310+
is_jax = enp.lazy.is_jax_xnp(xnp)
311+
is_torch = enp.lazy.is_torch_xnp(xnp)
312+
313+
if enp.lazy.is_np_xnp(xnp):
311314
return _vmap_method_np(args, map_non_static=map_non_static)
312-
elif xnp is enp.lazy.jnp:
315+
elif is_jax or is_torch:
316+
if is_jax:
317+
make_vmap_fn = _jax_vmap_cached
318+
elif is_torch:
319+
make_vmap_fn = _torch_vmap_cached
320+
else:
321+
raise ValueError('Unexpected')
313322
return _vmap_method_jax_torch(
314323
args,
315324
map_non_static=map_non_static,
316-
make_vmap_fn=_jax_vmap_cached,
325+
make_vmap_fn=make_vmap_fn,
317326
)
318-
elif xnp is enp.lazy.tnp:
319-
return _vmap_method_tf(args, map_non_static=map_non_static)
320-
elif xnp is enp.lazy.torch:
321-
return _vmap_method_jax_torch(
322-
args,
323-
map_non_static=map_non_static,
324-
make_vmap_fn=_torch_vmap_cached,
327+
elif enp.lazy.is_tf_xnp(xnp):
328+
# return _vmap_method_tf(args, map_non_static=map_non_static)
329+
330+
# TODO(epot): Use `tf.vectorized_map()` once TF support custom nesting
331+
raise NotImplementedError(
332+
'vectorization not supported in TF yet due to lack of `tf.nest` '
333+
'support. Please upvote or comment b/152678472.'
325334
)
326335
raise TypeError(f'Invalid numpy module: {xnp}')
327336

@@ -400,13 +409,52 @@ def _vmap_method_tf(
400409
map_non_static: _MapNonStatic,
401410
) -> _OutT:
402411
"""vectorization using `tf` backend."""
403-
# TODO(epot): Use `tf.vectorized_map()` once TF support custom nesting
404-
raise NotImplementedError(
405-
'vectorization not supported in TF yet due to lack of `tf.nest` '
406-
'support. Please upvote or comment b/152678472.'
412+
413+
# Flatten args
414+
415+
args_info = args.map(lambda _: None)
416+
# ... except the non-static ones
417+
args_info = map_non_static(lambda _: 0, args_info)
418+
419+
# Split args in static/non-static
420+
static_args = {}
421+
nonstatic_args = {}
422+
for a, ai in zip(args, args_info):
423+
assert a.name == ai.name
424+
if ai.value is None:
425+
static_args[a.name] = a.value
426+
else:
427+
nonstatic_args[a.name] = a.value
428+
429+
def new_fn(non_statics, statics):
430+
# Merge args and call the function
431+
new_args = args.replace_args_values(dict(**non_statics, **statics))
432+
return new_args.call()
433+
434+
# `vectorized_map(` uses autograph, which fails, so use tf.map_fn instead
435+
return _better_map_fn( #
436+
functools.partial(new_fn, statics=static_args),
437+
nonstatic_args,
407438
)
408439

409440

441+
# tf.map_fn do not support different output signature:
442+
def _better_map_fn(fn, elems, **kwargs):
443+
"""Like `tf.map_fn`."""
444+
tf = enp.lazy.tf
445+
if 'fn_output_signature' not in kwargs:
446+
elem_spec = tf.nest.map_structure(
447+
lambda t: tf.type_spec_from_value(t)._unbatch(), elems # pylint: disable=protected-access
448+
)
449+
output_spec = tf.nest.map_structure(
450+
tf.type_spec_from_value,
451+
tf.function(fn).get_concrete_function(elem_spec).structured_outputs,
452+
)
453+
kwargs['fn_output_signature'] = output_spec
454+
455+
return tf.map_fn(fn, elems, **kwargs)
456+
457+
410458
def _stack(*vals: _OutT) -> _OutT:
411459
"""Stack the given tree."""
412460
assert vals

0 commit comments

Comments
 (0)