
Commit db97d7a

Merge pull request jax-ml#25199 from Rifur13:save_residuals
PiperOrigin-RevId: 702824842
2 parents: 222b2e7 + a4e742d

2 files changed: +123 −25 lines


jax/experimental/pallas/ops/gpu/decode_attention.py

Lines changed: 60 additions & 11 deletions
@@ -143,6 +143,7 @@ def decode_attn_unbatched(
     grid: tuple[int, ...] | None,
     interpret: bool,
     debug: bool,
+    return_residuals: bool
 ):
   num_heads, head_dim = q.shape
   k_seq_len, _ = k.shape
@@ -215,7 +216,10 @@ def decode_attn_unbatched(
   l_next = (l * correction).sum(axis=0)
   eps = jnp.finfo(l_next.dtype).eps
   o = o.sum(axis=0) / (l_next[:, None].astype(o.dtype) + eps)
-  return o
+  if return_residuals:
+    return o, (l_next, m_next)
+  else:
+    return o


 @functools.partial(
@@ -230,6 +234,7 @@ def decode_attn_unbatched(
         "grid",
         "interpret",
         "debug",
+        "return_residuals"
     ],
 )
 def mqa(
@@ -247,6 +252,7 @@ def mqa(
     grid: tuple[int, ...] | None = None,
     interpret: bool = False,
     debug: bool = False,
+    return_residuals: bool = False
 ):
   sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
   bs = q.shape[0]
@@ -265,6 +271,7 @@ def mqa(
       grid=grid,
       interpret=interpret,
       debug=debug,
+      return_residuals=return_residuals
   )
   return jax.vmap(inner)(q, k, v, start_idx, kv_seq_len)

@@ -281,6 +288,7 @@ def mqa(
         "grid",
         "interpret",
         "debug",
+        "return_residuals"
     ],
 )
 def gqa(
@@ -298,6 +306,7 @@ def gqa(
     grid: tuple[int, ...] | None = None,
     interpret: bool = False,
     debug: bool = False,
+    return_residuals: bool = False,
 ):
   sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
   batch_size, q_heads, head_dim = q.shape
@@ -331,25 +340,40 @@ def gqa(
       grid=grid,
       interpret=interpret,
       debug=debug,
+      return_residuals=return_residuals,
   )
   with_kv_heads = jax.vmap(inner)
-  o = jax.vmap(with_kv_heads)(q_reshaped, k_transposed, v_transposed,
-                              start_idx, kv_seq_len)
-  return o.reshape(batch_size, q_heads, head_dim)
+  o, *res = jax.vmap(with_kv_heads)(
+      q_reshaped, k_transposed, v_transposed, start_idx, kv_seq_len
+  )
+  o = o.reshape(batch_size, q_heads, head_dim)
+  if return_residuals:
+    l, m = res[0]
+    l = l.reshape(batch_size, q_heads)
+    m = m.reshape(batch_size, q_heads)
+    return o, (l, m)
+  else:
+    return o


-@functools.partial(jax.jit, static_argnames=["sm_scale"])
+@functools.partial(jax.jit, static_argnames=["sm_scale", "return_residuals"])
 def mqa_reference(
     q,  # [bs, num_q_heads, head_dim]
     k,  # [bs, k_seq_len, head_dim]
     v,  # [bs, k_seq_len, head_dim]
     start_idx=None,  # [bs]
     kv_seq_len=None,  # [bs]
     sm_scale=None,
+    return_residuals=False
 ):
+  original_dtype = q.dtype
+  q = q.astype(jnp.float32)
+  k = k.astype(jnp.float32)
   bs = q.shape[0]
   sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
   logits = jnp.einsum("bnd,bsd->bns", q, k).astype(jnp.float32)
+  if sm_scale is not None and sm_scale != 1.0:
+    logits = logits * sm_scale
   if start_idx is not None or kv_seq_len is not None:
     start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,))
     kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None
@@ -358,8 +382,17 @@ def mqa_reference(
             & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None]))
     mask = mask[:, None, :]
     logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min)
-  weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
-  return jnp.einsum("bns,bsd->bnd", weights, v)
+
+  m = logits.max(axis=-1)
+  s = jnp.exp(logits - m[..., None])
+  l = s.sum(axis=-1)
+  s = s / l[..., None]
+  o = jnp.einsum("bns,bsd->bnd", s, v).astype(original_dtype)
+
+  if return_residuals:
+    return o, (l, m)
+  else:
+    return o


 @functools.partial(jax.jit, static_argnames=["sm_scale"])
@@ -387,15 +420,19 @@ def mha_reference(
   return jnp.einsum("bns,bsnd->bnd", weights, v)


-@functools.partial(jax.jit, static_argnames=["sm_scale"])
+@functools.partial(jax.jit, static_argnames=["sm_scale", "return_residuals"])
 def gqa_reference(
     q,  # [bs, num_q_heads, head_dim]
     k,  # [bs, k_seq_len, num_k_heads, head_dim]
     v,  # [bs, k_seq_len, num_v_heads, head_dim]
     start_idx=None,  # [bs]
     kv_seq_len=None,  # [bs]
     sm_scale=None,
+    return_residuals=False
 ):
+  original_dtype = q.dtype
+  q = q.astype(jnp.float32)
+  k = k.astype(jnp.float32)
   sm_scale = sm_scale if sm_scale is not None else (1 / math.sqrt(q.shape[-1]))
   bs, num_q_heads, head_dim = q.shape
   num_kv_heads = k.shape[2]
@@ -412,6 +449,8 @@ def gqa_reference(
   logits = jnp.einsum("bkgd,bksd->bkgs", q_reshaped, k_transposed).astype(
       jnp.float32
   )
+  if sm_scale is not None and sm_scale != 1.0:
+    logits = logits * sm_scale
   if start_idx is not None or kv_seq_len is not None:
     start_idx = jnp.broadcast_to(0 if start_idx is None else start_idx, (bs,))
     kv_seq_len = jnp.broadcast_to(k.shape[1] if kv_seq_len is None
@@ -420,7 +459,17 @@ def gqa_reference(
             & (jnp.arange(k.shape[1])[None, :] < kv_seq_len[:, None]))
     mask = mask[:, None, None, :]
     logits = logits + (~mask) * (0.7 * jnp.finfo(logits.dtype).min)
-  weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
-  o = jnp.einsum("bkgs,bksd->bkgd", weights, v_transposed)
+
+  m = logits.max(axis=-1)
+  s = jnp.exp(logits - m[..., None])
+  l = s.sum(axis=-1)
+  s = s / l[..., None]
+  o = jnp.einsum("bkgs,bksd->bkgd", s, v_transposed).astype(original_dtype)
   o = o.reshape(bs, num_q_heads, head_dim)
-  return o
+
+  if return_residuals:
+    l = l.reshape(bs, num_q_heads)
+    m = m.reshape(bs, num_q_heads)
+    return o, (l, m)
+  else:
+    return o
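For context, a minimal usage sketch of the new flag (not part of the diff; the shapes, seed, and printed values below are illustrative assumptions):

# Hypothetical example of calling the Pallas decode-attention MQA kernel with
# return_residuals=True. Needs a GPU with compute capability >= sm80; pass
# interpret=True to run without one.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu import decode_attention

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (2, 8, 64), dtype=jnp.float16)    # [bs, num_q_heads, head_dim]
k = jax.random.normal(k2, (2, 256, 64), dtype=jnp.float16)  # [bs, k_seq_len, head_dim]
v = jax.random.normal(k3, (2, 256, 64), dtype=jnp.float16)  # [bs, k_seq_len, head_dim]

# With return_residuals=True the call returns the normalized output o plus the
# softmax residuals (l, m): m is the per-query max of the scaled logits and
# l is the corresponding sum of exp(logits - m).
o, (l, m) = decode_attention.mqa(q, k, v, return_residuals=True)
print(o.shape, l.shape, m.shape)  # (2, 8, 64) (2, 8) (2, 8)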

tests/pallas/gpu_attention_test.py

Lines changed: 63 additions & 14 deletions
@@ -21,6 +21,7 @@
 from jax import random
 from jax._src import config
 from jax._src import test_util as jtu
+
 if sys.platform != "win32":
   from jax.experimental.pallas.ops.gpu import decode_attention
 else:
@@ -48,8 +49,9 @@ def setUp(self):
       self.skipTest("On CPU, the test works only in interpret mode")
     if jax.config.x64_enabled:
       self.skipTest("The test works only in 32-bit")
-    if (jtu.test_device_matches(["cuda"]) and
-        not jtu.is_cuda_compute_capability_at_least("8.0")):
+    if jtu.test_device_matches(
+        ["cuda"]
+    ) and not jtu.is_cuda_compute_capability_at_least("8.0"):
       self.skipTest("Only works on GPU with capability >= sm80")
     if sys.platform == "win32":
       self.skipTest("Only works on non-Windows platforms")
@@ -62,15 +64,18 @@ class DecodeAttentionTest(PallasBaseTest):

   @parameterized.named_parameters(*[
       (
-          (f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}_"
-           f"{start_idx=}_{kv_seq_len=}"),
+          (
+              f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{kwargs=}_"
+              f"{start_idx=}_{kv_seq_len=}_{return_residuals=}"
+          ),
           batch_size,
           seq_len,
           num_heads,
           head_dim,
           kwargs,
           start_idx,
           kv_seq_len,
+          return_residuals,
       )
       for (
           batch_size,
@@ -85,6 +90,7 @@ class DecodeAttentionTest(PallasBaseTest):
       ]
       for start_idx in [None, 123]
       for kv_seq_len in [None, 250]
+      for return_residuals in [False, True]
   ])
   @jax.numpy_dtype_promotion("standard")
   def test_mqa(
@@ -96,6 +102,7 @@ def test_mqa(
       kwargs,
       start_idx,
       kv_seq_len,
+      return_residuals,
   ):
     del kwargs

@@ -104,16 +111,36 @@ def test_mqa(
     k = random.normal(k2, (batch_size, seq_len, head_dim), dtype=jnp.float16)
     v = random.normal(k3, (batch_size, seq_len, head_dim), dtype=jnp.float16)

-    o = decode_attention.mqa(q, k, v, start_idx=start_idx,
-                             kv_seq_len=kv_seq_len, interpret=self.INTERPRET)
-    o_ref = decode_attention.mqa_reference(q, k, v, start_idx=start_idx,
-                                           kv_seq_len=kv_seq_len)
+    o, *res = decode_attention.mqa(
+        q,
+        k,
+        v,
+        start_idx=start_idx,
+        kv_seq_len=kv_seq_len,
+        return_residuals=return_residuals,
+        interpret=self.INTERPRET,
+    )
+    o_ref, *res_ref = decode_attention.mqa_reference(
+        q,
+        k,
+        v,
+        start_idx=start_idx,
+        kv_seq_len=kv_seq_len,
+        return_residuals=return_residuals,
+    )
     np.testing.assert_allclose(o, o_ref, atol=0.05)
+    if return_residuals:
+      l, m = res[0]
+      l_ref, m_ref = res_ref[0]
+      np.testing.assert_allclose(l, l_ref, atol=0.05)
+      np.testing.assert_allclose(m, m_ref, atol=0.05)

   @parameterized.named_parameters(*[
       (
-          (f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}"
-           f"_{kwargs=}_{start_idx=}_{kv_seq_len=}"),
+          (
+              f"{batch_size=}_{seq_len=}_{num_q_heads=}_{num_kv_heads=}_{head_dim=}"
+              f"_{kwargs=}_{start_idx=}_{kv_seq_len=}_{return_residuals=}"
+          ),
           batch_size,
           seq_len,
           num_q_heads,
@@ -122,6 +149,7 @@ def test_mqa(
           kwargs,
           start_idx,
           kv_seq_len,
+          return_residuals,
       )
       for (
           batch_size,
@@ -137,6 +165,7 @@ def test_mqa(
       ]
       for start_idx in [None, 123]
       for kv_seq_len in [None, 250]
+      for return_residuals in [False, True]
   ])
   @jax.numpy_dtype_promotion("standard")
   def test_gqa(
@@ -149,6 +178,7 @@ def test_gqa(
       kwargs,
       start_idx,
       kv_seq_len,
+      return_residuals,
   ):
     del kwargs

@@ -162,11 +192,30 @@ def test_gqa(
     v = random.normal(
         k3, (batch_size, seq_len, num_kv_heads, head_dim), dtype=jnp.float16
     )
-    o = decode_attention.gqa(q, k, v, start_idx=start_idx,
-                             kv_seq_len=kv_seq_len, interpret=self.INTERPRET)
-    o_ref = decode_attention.gqa_reference(q, k, v, start_idx=start_idx,
-                                           kv_seq_len=kv_seq_len)
+    o, *res = decode_attention.gqa(
+        q,
+        k,
+        v,
+        start_idx=start_idx,
+        kv_seq_len=kv_seq_len,
+        return_residuals=return_residuals,
+        interpret=self.INTERPRET,
+    )
+    o_ref, *res_ref = decode_attention.gqa_reference(
+        q,
+        k,
+        v,
+        start_idx=start_idx,
+        kv_seq_len=kv_seq_len,
+        return_residuals=return_residuals,
+    )
     np.testing.assert_allclose(o, o_ref, atol=0.05)
+    if return_residuals:
+      l, m = res[0]
+      l_ref, m_ref = res_ref[0]
+      np.testing.assert_allclose(l, l_ref, atol=0.05)
+      np.testing.assert_allclose(m, m_ref, atol=0.05)
+

 class DecodeAttentionInterpretTest(DecodeAttentionTest):
   INTERPRET = True
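The residuals exercised by these tests are what make it possible to merge attention outputs computed over disjoint chunks of the KV cache. A rough sketch of that merge using the reference functions from this commit (the chunk split, shapes, and tolerance are arbitrary choices for illustration):

# Hypothetical sketch: combine two partial attention results via their (l, m)
# residuals and check against attention over the full KV cache.
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.pallas.ops.gpu import decode_attention

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (1, 4, 64), dtype=jnp.float16)    # [bs, num_q_heads, head_dim]
k = jax.random.normal(k2, (1, 256, 64), dtype=jnp.float16)  # [bs, k_seq_len, head_dim]
v = jax.random.normal(k3, (1, 256, 64), dtype=jnp.float16)

# Attention over each half of the KV cache, keeping the residuals.
o1, (l1, m1) = decode_attention.mqa_reference(
    q, k[:, :128], v[:, :128], return_residuals=True)
o2, (l2, m2) = decode_attention.mqa_reference(
    q, k[:, 128:], v[:, 128:], return_residuals=True)

# Standard log-sum-exp merge: rescale each partial output by its share of the
# global softmax denominator, then renormalize.
m = jnp.maximum(m1, m2)
a1 = l1 * jnp.exp(m1 - m)
a2 = l2 * jnp.exp(m2 - m)
o_merged = (a1[..., None] * o1 + a2[..., None] * o2) / (a1 + a2)[..., None]

o_full = decode_attention.mqa_reference(q, k, v)
np.testing.assert_allclose(o_merged, o_full.astype(jnp.float32), atol=0.05)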
