Commit 8a31619

apaszke authored and Google-ML-Automation committed
[Pallas:MGPU] Make the shapes from the attention example more interesting
This bumps up the number of heads and removes the batch_size=2 case: it's very similar to batch_size=1 and doubles the script runtime. We also don't do full autotuning by default, since the largest block size that works usually performs the best.

PiperOrigin-RevId: 701976192
1 parent aff7714 commit 8a31619

File tree

1 file changed (+10 -8 lines)

jax/experimental/pallas/ops/gpu/attention_mgpu.py

Lines changed: 10 additions & 8 deletions
@@ -249,25 +249,26 @@ def attention_reference(q, k, v):
 
 
 def main(unused_argv):
-  num_q_heads = 1
-  num_kv_heads = 1
-  problem_it = itertools.product((1, 2), (4096, 32768,), (64, 128, 256,))
+  num_q_heads = 16
+  num_kv_heads = 16
+  problem_it = itertools.product((1,), (4096, 32768,), (64, 128, 256,))
   for batch_size, seq_len, head_dim in problem_it:
     q_seq_len = kv_seq_len = seq_len
     print(f"==== {batch_size=:<6} {kv_seq_len=:<6} {q_seq_len=:<6}"
           f"{num_q_heads=:<4} {head_dim=:<6} ====")
-    param_it = itertools.product((64,), (64, 128, 256))
-    best = None
     k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
     q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16)
     k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
     v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
-    for block_q, block_kv in param_it:
+    block_q = 64
+    best = None
+    for block_kv in (256, 128, 64):
       config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2)
       try:
         out, runtime_ms = profiler.measure(functools.partial(attention, config=config), q, k, v)
-        out_ref = attention_reference(q, k, v)
-        np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
+        if seq_len < 32768:
+          out_ref = attention_reference(q, k, v)
+          np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
       except ValueError as e:
         if "exceeds available shared memory" in e.args[0]:
           continue
@@ -285,6 +286,7 @@ def main(unused_argv):
       )
       if best is None or runtime_us < best[0]:
         best = (runtime_us, achieved_tc_util)
+      break  # Remove this for full autotuning.
     if best is not None:
       print(f"Best: {best[0]:<7.1f}us = {best[1]:4.1f}% TC utilization")
 
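In effect, the tuning loop is now a greedy heuristic rather than a search: it walks the block_kv candidates from largest to smallest, continues past any size whose kernel exceeds shared memory, and breaks after the first size that runs. A minimal, self-contained sketch of that pattern follows; SMEM_LIMIT, smem_usage, and pick_block_kv are hypothetical stand-ins invented for illustration, not names from attention_mgpu.py, and the byte budget is only a rough H100-class figure.

# Greedy "largest block that fits" selection, mirroring the diff's
# continue-on-overflow / break-on-first-success structure.
SMEM_LIMIT = 228 * 1024  # bytes; assumed shared-memory budget, not from the real kernel

def smem_usage(block_q, block_kv, head_dim, num_stages=2):
  # Crude stand-in cost model: fp16 (2-byte) Q tile plus double-buffered K/V tiles.
  return 2 * (block_q * head_dim + num_stages * 2 * block_kv * head_dim)

def pick_block_kv(head_dim, candidates=(256, 128, 64)):
  block_q = 64
  for block_kv in candidates:  # largest first
    if smem_usage(block_q, block_kv, head_dim) > SMEM_LIMIT:
      continue  # analogous to catching "exceeds available shared memory"
    return block_kv  # first size that fits wins, like the new break
  return None  # no candidate fits

for head_dim in (64, 128, 256):
  print(f"{head_dim=}: block_kv={pick_block_kv(head_dim)}")

Deleting the new break (as its comment says) restores the full sweep, measuring every candidate that fits instead of stopping at the first.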
