Skip to content

Commit 3fc2371

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
Make gmm TPU kernel tests significantly cheaper
We were testing lots of very similar cases that did not really help a lot with coverage. PiperOrigin-RevId: 707115030
1 parent 0ec902d commit 3fc2371

File tree

1 file changed

+15
-42
lines changed

1 file changed

+15
-42
lines changed

tests/pallas/tpu_gmm_test.py

Lines changed: 15 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,9 @@
4848
)
4949
hp.settings.load_profile("deterministic")
5050

51-
5251
def seed_strategy() -> hps.SearchStrategy[int]:
5352
return hps.integers(min_value=0, max_value=4)
5453

55-
5654
@hps.composite
5755
def group_strategy(
5856
draw: hps.DrawFn,
@@ -73,7 +71,6 @@ def group_strategy(
7371
)
7472
return num_groups, group_stride
7573

76-
7774
@hps.composite
7875
def group_sizes_strategy(
7976
draw: hps.DrawFn, m: int, num_groups: int
@@ -97,19 +94,12 @@ def group_sizes_strategy(
9794
starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final])
9895
return jnp.array(ends - starts, dtype=jnp.int32)
9996

100-
10197
GROUPED_MATMUL_TESTS = (
102-
(128, 128, 128),
103-
(256, 128, 128),
104-
(128, 256, 128),
105-
(128, 128, 256),
106-
(256, 128, 512),
107-
(512, 128, 128),
108-
(512, 2048, 128),
98+
(128, 128, 128), # Small
99+
(512, 2048, 256), # Big
109100
(128, 8, 16), # Test partial tiles.
110101
)
111102

112-
113103
def random_dense(
114104
shape: tuple[int, ...],
115105
key: jax.Array,
@@ -121,7 +111,6 @@ def random_dense(
121111
x = jax.random.uniform(key, shape, dtype, minval=-limit, maxval=limit) # pylint: disable=invalid-unary-operand-type
122112
return x.astype(jnp.bfloat16).astype(dtype)
123113

124-
125114
def dot(
126115
lhs: jnp.ndarray,
127116
rhs: jnp.ndarray,
@@ -133,7 +122,6 @@ def dot(
133122
rhs = jnp.transpose(rhs) if transpose_rhs else rhs
134123
return jax.lax.dot(lhs, rhs, preferred_element_type=preferred_element_type)
135124

136-
137125
def reference_gmm(
138126
lhs: jnp.ndarray,
139127
rhs: jnp.ndarray,
@@ -154,7 +142,6 @@ def reference_gmm(
154142
start += group_sizes[i]
155143
return jnp.concatenate(out, axis=0)
156144

157-
158145
def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]:
159146
dtypes = [jnp.float32, jnp.bfloat16]
160147

@@ -164,7 +151,6 @@ def with_dtype_arguments(xs: tuple[Any, ...]) -> tuple[Any, ...]:
164151
result.append(x + dtypes_tuple)
165152
return tuple(result)
166153

167-
168154
def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]:
169155
flags = [False, True]
170156
result = []
@@ -173,7 +159,6 @@ def with_transpose_argument(xs: tuple[Any, ...]) -> tuple[Any, ...]:
173159
result.append(x + (flag,))
174160
return tuple(result)
175161

176-
177162
def tolerances(
178163
lhs_dtype: jnp.dtype, rhs_dtype: jnp.dtype, out_dtype: jnp.dtype
179164
) -> tuple[float, float]:
@@ -185,7 +170,6 @@ def tolerances(
185170
return 1e-3, 1e-2 # atol, rtol
186171
return 1e-3, 1e-5 # atol, rtol
187172

188-
189173
# TODO(tgale): Fix errors with strict dtype promotion.
190174
@jtu.with_config(jax_numpy_dtype_promotion="standard")
191175
class GroupedMatmulTest(jtu.JaxTestCase):
@@ -218,15 +202,16 @@ def gmm_test(
218202
m: int,
219203
k: int,
220204
n: int,
221-
lhs_dtype: jnp.dtype,
222-
rhs_dtype: jnp.dtype,
223-
out_dtype: jnp.dtype,
224-
transpose_rhs: bool,
225205
data: hps.SearchStrategy[hps.DataObject],
226206
interpret: bool = False,
227207
):
228208
seed = data.draw(seed_strategy())
229209
num_groups, _ = data.draw(group_strategy(max_stride=1))
210+
lhs_dtype, rhs_dtype, out_dtype = [
211+
data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16]))
212+
for _ in range(3)
213+
]
214+
transpose_rhs = data.draw(hps.booleans())
230215

231216
key = jax.random.key(seed)
232217
k1, k2 = jax.random.split(key, 2)
@@ -270,64 +255,52 @@ def reference_fn(lhs, rhs, group_sizes, preferred_element_type):
270255
self.assert_allclose(grad_lhs, expected_grad_lhs, atol=atol, rtol=rtol)
271256
self.assert_allclose(grad_rhs, expected_grad_rhs, atol=atol, rtol=rtol)
272257

273-
@parameterized.parameters(
274-
*with_transpose_argument(with_dtype_arguments(GROUPED_MATMUL_TESTS))
275-
)
258+
@parameterized.parameters(*GROUPED_MATMUL_TESTS)
276259
@hp.given(hps.data())
277260
def test_gmm(
278261
self,
279262
m: int,
280263
k: int,
281264
n: int,
282-
lhs_dtype: jnp.dtype,
283-
rhs_dtype: jnp.dtype,
284-
out_dtype: jnp.dtype,
285-
transpose_rhs: bool,
286265
data: hps.SearchStrategy[hps.DataObject],
287266
):
288-
self.gmm_test(m, k, n, lhs_dtype, rhs_dtype, out_dtype, transpose_rhs, data)
267+
self.gmm_test(m, k, n, data)
289268

290269
# NOTE: Run fewer tests with interpret mode. We just want to sanity check that
291270
# changes do not break running these kernels with interpret=True.
292-
@parameterized.parameters(*with_dtype_arguments(GROUPED_MATMUL_TESTS[0:1]))
271+
@parameterized.parameters(*GROUPED_MATMUL_TESTS[0:1])
293272
@hp.given(hps.data())
294273
def test_gmm_interpret(
295274
self,
296275
m: int,
297276
k: int,
298277
n: int,
299-
lhs_dtype: jnp.dtype,
300-
rhs_dtype: jnp.dtype,
301-
out_dtype: jnp.dtype,
302278
data: hps.SearchStrategy[hps.DataObject],
303279
):
304280
self.skipTest("interpret mode with dynamic grids is unsupported")
305281
self.gmm_test(
306282
m,
307283
k,
308284
n,
309-
lhs_dtype,
310-
rhs_dtype,
311-
out_dtype,
312-
transpose_rhs=False,
313285
data=data,
314286
interpret=True,
315287
)
316288

317-
@parameterized.parameters(*with_dtype_arguments(GROUPED_MATMUL_TESTS))
289+
@parameterized.parameters(*GROUPED_MATMUL_TESTS)
318290
@hp.given(hps.data())
319291
def test_gmm_sharded_groups(
320292
self,
321293
m: int,
322294
k: int,
323295
n: int,
324-
lhs_dtype: jnp.dtype,
325-
rhs_dtype: jnp.dtype,
326-
out_dtype: jnp.dtype,
327296
data: hps.SearchStrategy[hps.DataObject],
328297
):
329298
seed = data.draw(seed_strategy())
330299
num_groups, group_stride = data.draw(group_strategy())
300+
lhs_dtype, rhs_dtype, out_dtype = [
301+
data.draw(hps.sampled_from([jnp.float32, jnp.bfloat16]))
302+
for _ in range(3)
303+
]
331304

332305
key = jax.random.key(seed)
333306
k1, k2 = jax.random.split(key, 2)

0 commit comments

Comments
 (0)