Skip to content

Commit 84f3f99

Browse files
[pallas] fix jumble test flakiness
* Enable interpret mode in tests * Ensure that the kernel is run multiple times where weve seen data corruption * Use masked comparison - prior comparison was reading garbage data as we were basically relying on past behavior of how uninitialized memory was behaving. * This was being hidden by a cache, where the interpret test, which always has 0.0 for uninitialized memory was being hit first, where TPU does not have the same behavior. PiperOrigin-RevId: 703272002
1 parent 651ab18 commit 84f3f99

File tree

1 file changed

+51
-46
lines changed

1 file changed

+51
-46
lines changed

tests/pallas/pallas_jumble_test.py

Lines changed: 51 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,21 @@
4141
floatx = dtypes.canonicalize_dtype(jnp.float64)
4242

4343

44+
def _assert_ragged_equal_with_elementwise_mask(
45+
row_count, col_grid_size, ragged_shape, res, ref
46+
):
47+
total_columns = col_grid_size * 128
48+
mask = jnp.zeros((len(ragged_shape), row_count, total_columns), dtype=bool)
49+
50+
for i, r in enumerate(ragged_shape):
51+
mask = mask.at[i, :, : r * 128].set(True)
52+
53+
res_valid = jnp.where(mask, res, -1)
54+
ref_valid = jnp.where(mask, ref, -1)
55+
56+
np.testing.assert_allclose(res_valid, ref_valid)
57+
58+
4459
@jtu.with_config(jax_traceback_filtering="off")
4560
class PallasBaseTest(jtu.JaxTestCase):
4661
INTERPRET = False
@@ -104,24 +119,16 @@ def invoke_kernel(x):
104119
axis_size=3,
105120
)(x)
106121

107-
res = res.data
108-
total = len(ragged_shape) * row_count * col_grid_size * 128
109-
res_total = np.prod(res.shape)
110-
self.assertEqual(res_total, total)
111-
ragged_total = 0
112-
for dim in ragged_shape:
113-
ragged_total += row_count * dim * 128
114-
115-
def correct(v):
116-
return np.count_nonzero(v == jnp.sin(1.0))
117-
118-
for b, batch in enumerate(res):
119-
ragged_val = ragged_shape[b]
120-
for r, row in enumerate(batch):
121-
row_total = ragged_val * 128
122-
self.assertEqual(correct(row), row_total, msg=f"row {r}, : {row}")
122+
ref = jax.vmap(
123+
jnp.sin,
124+
out_axes=batching.jumble_axis,
125+
in_axes=batching.jumble_axis,
126+
axis_size=3,
127+
)(x)
123128

124-
self.assertEqual(correct(res), ragged_total)
129+
_assert_ragged_equal_with_elementwise_mask(
130+
row_count, col_grid_size, ragged_shape, res.data, ref.data
131+
)
125132

126133
def test_vmap_jumble_over_add_kernel(self):
127134
if not jtu.test_device_matches(["tpu"]):
@@ -156,36 +163,34 @@ def invoke_kernel(x, y):
156163
(8, col_grid_size * 128), dtype=jnp.float32
157164
),
158165
grid=(1, col_grid_size),
159-
interpret=False,
166+
interpret=self.INTERPRET,
160167
)(x, y)
161168

162-
res = jax.vmap(
163-
invoke_kernel,
164-
out_axes=batching.jumble_axis,
165-
in_axes=batching.jumble_axis,
166-
axis_size=3,
167-
)(x, y)
169+
# We've had this test fail with data corruption due to multiple
170+
# invocations, so we run it k times to make sure it's not setting up
171+
# memory incorrectly for subsequent invocations.
172+
for _ in range(4):
173+
res = jax.vmap(
174+
invoke_kernel,
175+
out_axes=batching.jumble_axis,
176+
in_axes=batching.jumble_axis,
177+
axis_size=3,
178+
)(x, y)
168179

169-
res = res.data
170-
total = len(ragged_shape) * row_count * col_grid_size * 128
171-
res_total = np.prod(res.shape)
172-
self.assertEqual(res_total, total)
173-
ragged_total = 0
174-
for dim in ragged_shape:
175-
ragged_total += row_count * dim * 128
176-
177-
def correct(v):
178-
return np.count_nonzero(v == 2.0)
179-
180-
for r, row in enumerate(res):
181-
ragged_val = ragged_shape[r]
182-
row_total = ragged_val * 128 * row_count
183-
self.assertEqual(correct(row), row_total)
184-
for col in row:
185-
col_total = ragged_val * 128
186-
self.assertEqual(correct(col), col_total)
187-
188-
self.assertEqual(np.count_nonzero(res == 2.0), ragged_total)
180+
res = res.data
181+
total = len(ragged_shape) * row_count * col_grid_size * 128
182+
res_total = np.prod(res.shape)
183+
self.assertEqual(res_total, total)
184+
185+
ref = jax.vmap(
186+
lambda x, y: x + y,
187+
out_axes=batching.jumble_axis,
188+
in_axes=batching.jumble_axis,
189+
axis_size=3,
190+
)(x, y)
191+
_assert_ragged_equal_with_elementwise_mask(
192+
row_count, col_grid_size, ragged_shape, res, ref.data
193+
)
189194

190195
def test_vmap_jumble_over_sin_kernel_grid_remapping(self):
191196
if not jtu.test_device_matches(["tpu"]):
@@ -212,7 +217,7 @@ def invoke_kernel(x):
212217
out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)),
213218
out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32),
214219
grid=(1, 5),
215-
interpret=False,
220+
interpret=self.INTERPRET,
216221
)(x)
217222

218223
with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"):
@@ -280,7 +285,7 @@ def matmul(
280285
),
281286
grid=grid,
282287
input_output_aliases={2: 0},
283-
interpret=False,
288+
interpret=self.INTERPRET,
284289
)(x, y, x_sentinel)
285290

286291
# TODO(mvoz): parameterize this shape?

0 commit comments

Comments
 (0)