2121
2222from dataclass_array import array_dataclass
2323from dataclass_array .typing import FloatArray # pylint: disable=g-multiple-import
24+ from etils import enp
2425from etils .etree import jax as etree
2526from etils .etree import Tree
2627import numpy as np
@@ -61,6 +62,9 @@ def assert_allclose(
6162def assert_array_equal (
6263 x ,
6364 y ,
65+ * ,
66+ atol : Optional [float ] = None ,
67+ rtol : Optional [float ] = None ,
6468) -> None :
6569 """Assert the 2 objects are equals.
6670
@@ -71,9 +75,26 @@ def assert_array_equal(
7175 Args:
7276 x: First element to compare
7377 y: Second element to compare
78+ atol: Absolute tolerance
79+ rtol: Relative tolerance
7480 """
7581 assert type (x ) == type (y ) # pylint: disable=unidiomatic-typecheck
7682 assert x .shape == y .shape
77- assert_allclose (x , y )
83+ assert_allclose (x , y , atol = atol , rtol = rtol )
7884 if isinstance (x , array_dataclass .DataclassArray ):
7985 assert x .xnp is y .xnp
86+
87+
88+ def skip_vmap_unavailable (xnp : enp .NpModule , * , skip_torch : str = '' ) -> None :
89+ """Skip the test when vmap not available."""
90+ skip = False
91+ if enp .lazy .is_tf_xnp (xnp ):
92+ # TODO(b/152678472): TF do not support vmap & tf.nest
93+ skip = True
94+ elif enp .lazy .is_torch_xnp (xnp ):
95+ if skip_torch :
96+ skip = True
97+ if skip :
98+ import pytest # pylint: disable=g-import-not-at-top # pytype: disable=import-error
99+
100+ pytest .skip ('Vectorization not supported yet with TF / Torch' )
0 commit comments