@@ -143,6 +143,7 @@ def decode_attn_unbatched(
143143 grid : tuple [int , ...] | None ,
144144 interpret : bool ,
145145 debug : bool ,
146+ return_residuals : bool
146147):
147148 num_heads , head_dim = q .shape
148149 k_seq_len , _ = k .shape
@@ -215,7 +216,10 @@ def decode_attn_unbatched(
215216 l_next = (l * correction ).sum (axis = 0 )
216217 eps = jnp .finfo (l_next .dtype ).eps
217218 o = o .sum (axis = 0 ) / (l_next [:, None ].astype (o .dtype ) + eps )
218- return o
219+ if return_residuals :
220+ return o , (l_next , m_next )
221+ else :
222+ return o
219223
220224
221225@functools .partial (
@@ -230,6 +234,7 @@ def decode_attn_unbatched(
230234 "grid" ,
231235 "interpret" ,
232236 "debug" ,
237+ "return_residuals"
233238 ],
234239)
235240def mqa (
@@ -247,6 +252,7 @@ def mqa(
247252 grid : tuple [int , ...] | None = None ,
248253 interpret : bool = False ,
249254 debug : bool = False ,
255+ return_residuals : bool = False
250256):
251257 sm_scale = sm_scale if sm_scale is not None else (1 / math .sqrt (q .shape [- 1 ]))
252258 bs = q .shape [0 ]
@@ -265,6 +271,7 @@ def mqa(
265271 grid = grid ,
266272 interpret = interpret ,
267273 debug = debug ,
274+ return_residuals = return_residuals
268275 )
269276 return jax .vmap (inner )(q , k , v , start_idx , kv_seq_len )
270277
@@ -281,6 +288,7 @@ def mqa(
281288 "grid" ,
282289 "interpret" ,
283290 "debug" ,
291+ "return_residuals"
284292 ],
285293)
286294def gqa (
@@ -298,6 +306,7 @@ def gqa(
298306 grid : tuple [int , ...] | None = None ,
299307 interpret : bool = False ,
300308 debug : bool = False ,
309+ return_residuals : bool = False ,
301310):
302311 sm_scale = sm_scale if sm_scale is not None else (1 / math .sqrt (q .shape [- 1 ]))
303312 batch_size , q_heads , head_dim = q .shape
@@ -331,25 +340,40 @@ def gqa(
331340 grid = grid ,
332341 interpret = interpret ,
333342 debug = debug ,
343+ return_residuals = return_residuals ,
334344 )
335345 with_kv_heads = jax .vmap (inner )
336- o = jax .vmap (with_kv_heads )(q_reshaped , k_transposed , v_transposed ,
337- start_idx , kv_seq_len )
338- return o .reshape (batch_size , q_heads , head_dim )
346+ o , * res = jax .vmap (with_kv_heads )(
347+ q_reshaped , k_transposed , v_transposed , start_idx , kv_seq_len
348+ )
349+ o = o .reshape (batch_size , q_heads , head_dim )
350+ if return_residuals :
351+ l , m = res [0 ]
352+ l = l .reshape (batch_size , q_heads )
353+ m = m .reshape (batch_size , q_heads )
354+ return o , (l , m )
355+ else :
356+ return o
339357
340358
341- @functools .partial (jax .jit , static_argnames = ["sm_scale" ])
359+ @functools .partial (jax .jit , static_argnames = ["sm_scale" , "return_residuals" ])
342360def mqa_reference (
343361 q , # [bs, num_q_heads, head_dim]
344362 k , # [bs, k_seq_len, head_dim]
345363 v , # [bs, k_seq_len, head_dim]
346364 start_idx = None , # [bs]
347365 kv_seq_len = None , # [bs]
348366 sm_scale = None ,
367+ return_residuals = False
349368):
369+ original_dtype = q .dtype
370+ q = q .astype (jnp .float32 )
371+ k = k .astype (jnp .float32 )
350372 bs = q .shape [0 ]
351373 sm_scale = sm_scale if sm_scale is not None else (1 / math .sqrt (q .shape [- 1 ]))
352374 logits = jnp .einsum ("bnd,bsd->bns" , q , k ).astype (jnp .float32 )
375+ if sm_scale is not None and sm_scale != 1.0 :
376+ logits = logits * sm_scale
353377 if start_idx is not None or kv_seq_len is not None :
354378 start_idx = jnp .broadcast_to (0 if start_idx is None else start_idx , (bs ,))
355379 kv_seq_len = jnp .broadcast_to (k .shape [1 ] if kv_seq_len is None
@@ -358,8 +382,17 @@ def mqa_reference(
358382 & (jnp .arange (k .shape [1 ])[None , :] < kv_seq_len [:, None ]))
359383 mask = mask [:, None , :]
360384 logits = logits + (~ mask ) * (0.7 * jnp .finfo (logits .dtype ).min )
361- weights = jax .nn .softmax (logits * sm_scale ).astype (q .dtype )
362- return jnp .einsum ("bns,bsd->bnd" , weights , v )
385+
386+ m = logits .max (axis = - 1 )
387+ s = jnp .exp (logits - m [..., None ])
388+ l = s .sum (axis = - 1 )
389+ s = s / l [..., None ]
390+ o = jnp .einsum ("bns,bsd->bnd" , s , v ).astype (original_dtype )
391+
392+ if return_residuals :
393+ return o , (l , m )
394+ else :
395+ return o
363396
364397
365398@functools .partial (jax .jit , static_argnames = ["sm_scale" ])
@@ -387,15 +420,19 @@ def mha_reference(
387420 return jnp .einsum ("bns,bsnd->bnd" , weights , v )
388421
389422
390- @functools .partial (jax .jit , static_argnames = ["sm_scale" ])
423+ @functools .partial (jax .jit , static_argnames = ["sm_scale" , "return_residuals" ])
391424def gqa_reference (
392425 q , # [bs, num_q_heads, head_dim]
393426 k , # [bs, k_seq_len, num_k_heads, head_dim]
394427 v , # [bs, k_seq_len, num_v_heads, head_dim]
395428 start_idx = None , # [bs]
396429 kv_seq_len = None , # [bs]
397430 sm_scale = None ,
431+ return_residuals = False
398432):
433+ original_dtype = q .dtype
434+ q = q .astype (jnp .float32 )
435+ k = k .astype (jnp .float32 )
399436 sm_scale = sm_scale if sm_scale is not None else (1 / math .sqrt (q .shape [- 1 ]))
400437 bs , num_q_heads , head_dim = q .shape
401438 num_kv_heads = k .shape [2 ]
@@ -412,6 +449,8 @@ def gqa_reference(
412449 logits = jnp .einsum ("bkgd,bksd->bkgs" , q_reshaped , k_transposed ).astype (
413450 jnp .float32
414451 )
452+ if sm_scale is not None and sm_scale != 1.0 :
453+ logits = logits * sm_scale
415454 if start_idx is not None or kv_seq_len is not None :
416455 start_idx = jnp .broadcast_to (0 if start_idx is None else start_idx , (bs ,))
417456 kv_seq_len = jnp .broadcast_to (k .shape [1 ] if kv_seq_len is None
@@ -420,7 +459,17 @@ def gqa_reference(
420459 & (jnp .arange (k .shape [1 ])[None , :] < kv_seq_len [:, None ]))
421460 mask = mask [:, None , None , :]
422461 logits = logits + (~ mask ) * (0.7 * jnp .finfo (logits .dtype ).min )
423- weights = jax .nn .softmax (logits * sm_scale ).astype (q .dtype )
424- o = jnp .einsum ("bkgs,bksd->bkgd" , weights , v_transposed )
462+
463+ m = logits .max (axis = - 1 )
464+ s = jnp .exp (logits - m [..., None ])
465+ l = s .sum (axis = - 1 )
466+ s = s / l [..., None ]
467+ o = jnp .einsum ("bkgs,bksd->bkgd" , s , v_transposed ).astype (original_dtype )
425468 o = o .reshape (bs , num_q_heads , head_dim )
426- return o
469+
470+ if return_residuals :
471+ l = l .reshape (bs , num_q_heads )
472+ m = m .reshape (bs , num_q_heads )
473+ return o , (l , m )
474+ else :
475+ return o
0 commit comments