4141floatx = dtypes .canonicalize_dtype (jnp .float64 )
4242
4343
def _assert_ragged_equal_with_elementwise_mask(
    row_count, col_grid_size, ragged_shape, res, ref, *, block_cols=128
):
  """Assert that `res` and `ref` agree on the valid part of a ragged batch.

  A boolean mask of shape
  `(len(ragged_shape), row_count, col_grid_size * block_cols)` is built in
  which, for batch entry `i`, the first `ragged_shape[i] * block_cols`
  columns of every row are marked valid. Every position outside the mask is
  replaced by `-1` in both arrays, so only the valid positions can make the
  comparison fail.

  Args:
    row_count: Number of rows in each batch entry.
    col_grid_size: Number of column blocks in the padded layout.
    ragged_shape: Per-batch-entry count of valid column blocks.
    res: Array under test; must broadcast against the mask.
    ref: Reference array; must broadcast against the mask.
    block_cols: Width, in columns, of one column block. Defaults to 128,
      the value this helper previously hard-coded.

  Raises:
    AssertionError: If `res` and `ref` differ at any valid position, using
      the default tolerances of `np.testing.assert_allclose`.
  """
  total_columns = col_grid_size * block_cols

  # The mask is plain host-side bookkeeping, so fill it in place with NumPy
  # rather than producing a fresh functional-update copy per batch entry.
  mask = np.zeros((len(ragged_shape), row_count, total_columns), dtype=bool)
  for batch_index, valid_blocks in enumerate(ragged_shape):
    mask[batch_index, :, : valid_blocks * block_cols] = True

  # Overwrite the padded region with the same sentinel on both sides so that
  # whatever the kernel left there is ignored by the comparison.
  res_valid = jnp.where(mask, res, -1)
  ref_valid = jnp.where(mask, ref, -1)

  np.testing.assert_allclose(res_valid, ref_valid)
57+
58+
4459@jtu .with_config (jax_traceback_filtering = "off" )
4560class PallasBaseTest (jtu .JaxTestCase ):
4661 INTERPRET = False
@@ -104,24 +119,16 @@ def invoke_kernel(x):
104119 axis_size = 3 ,
105120 )(x )
106121
107- res = res .data
108- total = len (ragged_shape ) * row_count * col_grid_size * 128
109- res_total = np .prod (res .shape )
110- self .assertEqual (res_total , total )
111- ragged_total = 0
112- for dim in ragged_shape :
113- ragged_total += row_count * dim * 128
114-
115- def correct (v ):
116- return np .count_nonzero (v == jnp .sin (1.0 ))
117-
118- for b , batch in enumerate (res ):
119- ragged_val = ragged_shape [b ]
120- for r , row in enumerate (batch ):
121- row_total = ragged_val * 128
122- self .assertEqual (correct (row ), row_total , msg = f"row { r } , : { row } " )
122+ ref = jax .vmap (
123+ jnp .sin ,
124+ out_axes = batching .jumble_axis ,
125+ in_axes = batching .jumble_axis ,
126+ axis_size = 3 ,
127+ )(x )
123128
124- self .assertEqual (correct (res ), ragged_total )
129+ _assert_ragged_equal_with_elementwise_mask (
130+ row_count , col_grid_size , ragged_shape , res .data , ref .data
131+ )
125132
126133 def test_vmap_jumble_over_add_kernel (self ):
127134 if not jtu .test_device_matches (["tpu" ]):
@@ -156,36 +163,34 @@ def invoke_kernel(x, y):
156163 (8 , col_grid_size * 128 ), dtype = jnp .float32
157164 ),
158165 grid = (1 , col_grid_size ),
159- interpret = False ,
166+ interpret = self . INTERPRET ,
160167 )(x , y )
161168
162- res = jax .vmap (
163- invoke_kernel ,
164- out_axes = batching .jumble_axis ,
165- in_axes = batching .jumble_axis ,
166- axis_size = 3 ,
167- )(x , y )
169+ # We've had this test fail with data corruption across repeated
170+ # invocations, so we run it several times to make sure it's not setting up
171+ # memory incorrectly for subsequent invocations.
172+ for _ in range (4 ):
173+ res = jax .vmap (
174+ invoke_kernel ,
175+ out_axes = batching .jumble_axis ,
176+ in_axes = batching .jumble_axis ,
177+ axis_size = 3 ,
178+ )(x , y )
168179
169- res = res .data
170- total = len (ragged_shape ) * row_count * col_grid_size * 128
171- res_total = np .prod (res .shape )
172- self .assertEqual (res_total , total )
173- ragged_total = 0
174- for dim in ragged_shape :
175- ragged_total += row_count * dim * 128
176-
177- def correct (v ):
178- return np .count_nonzero (v == 2.0 )
179-
180- for r , row in enumerate (res ):
181- ragged_val = ragged_shape [r ]
182- row_total = ragged_val * 128 * row_count
183- self .assertEqual (correct (row ), row_total )
184- for col in row :
185- col_total = ragged_val * 128
186- self .assertEqual (correct (col ), col_total )
187-
188- self .assertEqual (np .count_nonzero (res == 2.0 ), ragged_total )
180+ res = res .data
181+ total = len (ragged_shape ) * row_count * col_grid_size * 128
182+ res_total = np .prod (res .shape )
183+ self .assertEqual (res_total , total )
184+
185+ ref = jax .vmap (
186+ lambda x , y : x + y ,
187+ out_axes = batching .jumble_axis ,
188+ in_axes = batching .jumble_axis ,
189+ axis_size = 3 ,
190+ )(x , y )
191+ _assert_ragged_equal_with_elementwise_mask (
192+ row_count , col_grid_size , ragged_shape , res , ref .data
193+ )
189194
190195 def test_vmap_jumble_over_sin_kernel_grid_remapping (self ):
191196 if not jtu .test_device_matches (["tpu" ]):
@@ -212,7 +217,7 @@ def invoke_kernel(x):
212217 out_specs = pl .BlockSpec ((8 , 128 ), lambda j , k : (j , k )),
213218 out_shape = jax .ShapeDtypeStruct ((8 , 640 ), dtype = jnp .float32 ),
214219 grid = (1 , 5 ),
215- interpret = False ,
220+ interpret = self . INTERPRET ,
216221 )(x )
217222
218223 with self .assertRaisesRegex (ValueError , "Axis 2 is out of bounds for grid" ):
@@ -280,7 +285,7 @@ def matmul(
280285 ),
281286 grid = grid ,
282287 input_output_aliases = {2 : 0 },
283- interpret = False ,
288+ interpret = self . INTERPRET ,
284289 )(x , y , x_sentinel )
285290
286291 # TODO(mvoz): parameterize this shape?
0 commit comments