File tree Expand file tree Collapse file tree 3 files changed +4
-10
lines changed
Expand file tree Collapse file tree 3 files changed +4
-10
lines changed Original file line number Diff line number Diff line change @@ -23,6 +23,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
2323
2424## [ Unreleased]
2525
26+ * Add ` dca.vectorize_method ` compatibility for ` tf.nest ` /` tf.data ` .
27+
2628## [ 1.5.0] - 2023-07-10
2729
2830* Add ` tf.nest ` /` tf.data ` compatibility for ` DataclassArray ` .
Original file line number Diff line number Diff line change @@ -88,10 +88,7 @@ def assert_array_equal(
8888def skip_vmap_unavailable (xnp : enp .NpModule , * , skip_torch : str = '' ) -> None :
8989 """Skip the test when vmap not available."""
9090 skip = False
91- if enp .lazy .is_tf_xnp (xnp ):
92- # TODO(b/152678472): TF do not support vmap & tf.nest
93- skip = True
94- elif enp .lazy .is_torch_xnp (xnp ):
91+ if enp .lazy .is_torch_xnp (xnp ):
9592 if skip_torch :
9693 skip = True
9794 if skip :
Original file line number Diff line number Diff line change @@ -325,13 +325,8 @@ def _vmap_method(
325325 make_vmap_fn = make_vmap_fn ,
326326 )
327327 elif enp .lazy .is_tf_xnp (xnp ):
328- # return _vmap_method_tf(args, map_non_static=map_non_static)
328+ return _vmap_method_tf (args , map_non_static = map_non_static )
329329
330- # TODO(epot): Use `tf.vectorized_map()` once TF support custom nesting
331- raise NotImplementedError (
332- 'vectorization not supported in TF yet due to lack of `tf.nest` '
333- 'support. Please upvote or comment b/152678472.'
334- )
335330 raise TypeError (f'Invalid numpy module: { xnp } ' )
336331
337332
You can’t perform that action at this time.
0 commit comments