2020
2121from absl .testing import absltest
2222from absl .testing import parameterized
23- import chex
2423from flax import jax_utils
2524import jax
2625import jax .numpy as jnp
2928NDEV = 4
3029
3130
32- class PadShardUnpadTest (chex .TestCase ):
31+ def assert_max_traces (n ):
32+ """Decorator to assert that a function is traced at most n times."""
33+ def decorator (fn ):
34+ trace_count = {'count' : 0 }
35+
36+ def wrapped (* args , ** kwargs ):
37+ trace_count ['count' ] += 1
38+ if trace_count ['count' ] > n :
39+ raise AssertionError (
40+ f"Function was traced { trace_count ['count' ]} times, "
41+ f"expected at most { n } traces"
42+ )
43+ return fn (* args , ** kwargs )
44+
45+ wrapped .trace_count = trace_count
46+ return wrapped
47+ return decorator
48+
49+
50+ class PadShardUnpadTest (parameterized .TestCase ):
3351 BATCH_SIZES = [NDEV , NDEV + 1 , NDEV - 1 , 5 * NDEV , 5 * NDEV + 1 , 5 * NDEV - 1 ]
3452 DTYPES = [np .float32 , np .uint8 , jax .numpy .bfloat16 , np .int32 ]
3553
36- def tearDown (self ):
37- chex .clear_trace_counter ()
38- super ().tearDown ()
54+
3955
4056 @parameterized .product (dtype = DTYPES , bs = BATCH_SIZES )
4157 def test_basics (self , dtype , bs ):
@@ -47,7 +63,7 @@ def add(a, b):
4763
4864 x = np .arange (bs , dtype = dtype )
4965 y = add (x , 10 * x )
50- chex . assert_type (y .dtype , x .dtype )
66+ self . assertEqual (y .dtype , x .dtype )
5167 np .testing .assert_allclose (np .float64 (y ), np .float64 (x + 10 * x ))
5268
5369 @parameterized .product (dtype = DTYPES , bs = BATCH_SIZES )
@@ -59,24 +75,22 @@ def add(a, b):
5975
6076 x = jnp .arange (bs , dtype = dtype )
6177 y = add (dict (a = x ), (10 * x ,))
62- chex . assert_type (y .dtype , x .dtype )
78+ self . assertEqual (y .dtype , x .dtype )
6379 np .testing .assert_allclose (np .float64 (y ), np .float64 (x + 10 * x ))
6480
6581 @parameterized .parameters (DTYPES )
6682 def test_min_device_batch_avoids_recompile (self , dtype ):
6783 @partial (jax_utils .pad_shard_unpad , static_argnums = ())
6884 @jax .jit
69- @chex . assert_max_traces (n = 1 )
85+ @assert_max_traces (n = 1 )
7086 def add (a , b ):
7187 b = jnp .asarray (b , dtype = dtype )
7288 return a + b
7389
74- chex .clear_trace_counter ()
75-
7690 for bs in self .BATCH_SIZES :
7791 x = jnp .arange (bs , dtype = dtype )
7892 y = add (x , 10 * x , min_device_batch = 9 ) # pylint: disable=unexpected-keyword-arg
79- chex . assert_type (y .dtype , x .dtype )
93+ self . assertEqual (y .dtype , x .dtype )
8094 np .testing .assert_allclose (np .float64 (y ), np .float64 (x + 10 * x ))
8195
8296 @parameterized .product (dtype = DTYPES , bs = BATCH_SIZES )
@@ -87,7 +101,7 @@ def add(a, b):
87101
88102 x = jnp .arange (bs , dtype = dtype )
89103 y = add (x , 10 )
90- chex . assert_type (y .dtype , x .dtype )
104+ self . assertEqual (y .dtype , x .dtype )
91105 np .testing .assert_allclose (np .float64 (y ), np .float64 (x + 10 ))
92106
93107 @parameterized .product (dtype = DTYPES , bs = BATCH_SIZES )
@@ -102,7 +116,7 @@ def add(params, a, *, b):
102116
103117 x = jnp .arange (bs , dtype = dtype )
104118 y = add (5 , x , b = 10 )
105- chex . assert_type (y .dtype , x .dtype )
119+ self . assertEqual (y .dtype , x .dtype )
106120 np .testing .assert_allclose (np .float64 (y ), np .float64 (5 * x + 10 ))
107121
108122
0 commit comments