Skip to content

Commit 237cf93

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Activate test for TF and dca.vectorize_method support
PiperOrigin-RevId: 561646634
1 parent 3f5dc03 commit 237cf93

File tree

3 files changed

+4
-10
lines changed

3 files changed

+4
-10
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323

2424
## [Unreleased]
2525

26+
* Add `dca.vectorize_method` compatibility for `tf.nest`/`tf.data`.
27+
2628
## [1.5.0] - 2023-07-10
2729

2830
* Add `tf.nest`/`tf.data` compatibility for `DataclassArray`.

dataclass_array/testing.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,7 @@ def assert_array_equal(
8888
def skip_vmap_unavailable(xnp: enp.NpModule, *, skip_torch: str = '') -> None:
8989
"""Skip the test when vmap not available."""
9090
skip = False
91-
if enp.lazy.is_tf_xnp(xnp):
92-
# TODO(b/152678472): TF do not support vmap & tf.nest
93-
skip = True
94-
elif enp.lazy.is_torch_xnp(xnp):
91+
if enp.lazy.is_torch_xnp(xnp):
9592
if skip_torch:
9693
skip = True
9794
if skip:

dataclass_array/vectorization.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -325,13 +325,8 @@ def _vmap_method(
325325
make_vmap_fn=make_vmap_fn,
326326
)
327327
elif enp.lazy.is_tf_xnp(xnp):
328-
# return _vmap_method_tf(args, map_non_static=map_non_static)
328+
return _vmap_method_tf(args, map_non_static=map_non_static)
329329

330-
# TODO(epot): Use `tf.vectorized_map()` once TF support custom nesting
331-
raise NotImplementedError(
332-
'vectorization not supported in TF yet due to lack of `tf.nest` '
333-
'support. Please upvote or comment b/152678472.'
334-
)
335330
raise TypeError(f'Invalid numpy module: {xnp}')
336331

337332

0 commit comments

Comments
 (0)