2020
2121from absl .testing import absltest
2222from absl .testing import parameterized
23- import chex
2423from flax import jax_utils
2524import jax
2625import jax .numpy as jnp
2928NDEV = 4
3029
3130
32- class PadShardUnpadTest (chex .TestCase ):
31+ def assert_max_traces (n ):
32+ """Decorator to assert that a function is traced at most n times."""
33+ from functools import wraps
34+ def decorator (fn ):
35+ trace_count = {'count' : 0 }
36+
37+ @wraps (fn )
38+ def wrapped (* args , ** kwargs ):
39+ trace_count ['count' ] += 1
40+ if trace_count ['count' ] > n :
41+ raise AssertionError (
42+ f"Function was traced { trace_count ['count' ]} times, "
43+ f"expected at most { n } traces"
44+ )
45+ return fn (* args , ** kwargs )
46+
47+ wrapped .trace_count = trace_count
48+ return wrapped
49+ return decorator
50+
51+
52+ class PadShardUnpadTest (parameterized .TestCase ):
3353 BATCH_SIZES = [NDEV , NDEV + 1 , NDEV - 1 , 5 * NDEV , 5 * NDEV + 1 , 5 * NDEV - 1 ]
3454 DTYPES = [np .float32 , np .uint8 , jax .numpy .bfloat16 , np .int32 ]
3555
36- def tearDown (self ):
37- chex .clear_trace_counter ()
38- super ().tearDown ()
56+
3957
4058 @parameterized .product (dtype = DTYPES , bs = BATCH_SIZES )
4159 def test_basics (self , dtype , bs ):
@@ -47,7 +65,7 @@ def add(a, b):
4765
4866 x = np .arange (bs , dtype = dtype )
4967 y = add (x , 10 * x )
50- chex . assert_type (y .dtype , x .dtype )
68+ self . assertEqual (y .dtype , x .dtype )
5169 np .testing .assert_allclose (np .float64 (y ), np .float64 (x + 10 * x ))
5270
5371 @parameterized .product (dtype = DTYPES , bs = BATCH_SIZES )
@@ -59,24 +77,22 @@ def add(a, b):
5977
6078 x = jnp .arange (bs , dtype = dtype )
6179 y = add (dict (a = x ), (10 * x ,))
62- chex . assert_type (y .dtype , x .dtype )
80+ self . assertEqual (y .dtype , x .dtype )
6381 np .testing .assert_allclose (np .float64 (y ), np .float64 (x + 10 * x ))
6482
6583 @parameterized .parameters (DTYPES )
6684 def test_min_device_batch_avoids_recompile (self , dtype ):
6785 @partial (jax_utils .pad_shard_unpad , static_argnums = ())
6886 @jax .jit
69- @chex . assert_max_traces (n = 1 )
87+ @assert_max_traces (n = 1 )
7088 def add (a , b ):
7189 b = jnp .asarray (b , dtype = dtype )
7290 return a + b
7391
74- chex .clear_trace_counter ()
75-
7692 for bs in self .BATCH_SIZES :
7793 x = jnp .arange (bs , dtype = dtype )
7894 y = add (x , 10 * x , min_device_batch = 9 ) # pylint: disable=unexpected-keyword-arg
79- chex . assert_type (y .dtype , x .dtype )
95+ self . assertEqual (y .dtype , x .dtype )
8096 np .testing .assert_allclose (np .float64 (y ), np .float64 (x + 10 * x ))
8197
8298 @parameterized .product (dtype = DTYPES , bs = BATCH_SIZES )
@@ -87,7 +103,7 @@ def add(a, b):
87103
88104 x = jnp .arange (bs , dtype = dtype )
89105 y = add (x , 10 )
90- chex . assert_type (y .dtype , x .dtype )
106+ self . assertEqual (y .dtype , x .dtype )
91107 np .testing .assert_allclose (np .float64 (y ), np .float64 (x + 10 ))
92108
93109 @parameterized .product (dtype = DTYPES , bs = BATCH_SIZES )
@@ -102,7 +118,7 @@ def add(params, a, *, b):
102118
103119 x = jnp .arange (bs , dtype = dtype )
104120 y = add (5 , x , b = 10 )
105- chex . assert_type (y .dtype , x .dtype )
121+ self . assertEqual (y .dtype , x .dtype )
106122 np .testing .assert_allclose (np .float64 (y ), np .float64 (5 * x + 10 ))
107123
108124
0 commit comments