
Commit 7b45552

bythew3i authored and Google-ML-Automation committed
[ragged-paged-attn] Unify kv strided load to one.
I expected Mosaic to canonicalize two identical strided loads into one, but it does not yet (we will fix this in Mosaic). For now, manually merging them into a single strided load yields a 20~35% speedup on both v6e and v5e single chips for Meta-Llama-3-8B.

PiperOrigin-RevId: 745294058
1 parent 8301c30 commit 7b45552
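
For context, here is a minimal sketch of the access pattern this commit unifies, using hypothetical shapes rather than the kernel's real buffers. K and V heads are interleaved along the leading axis (K at even offsets, V at odd offsets), so the two old loads shared one stride; the float32 branch of the new strided_load_kv returns both slices from a single call:

import jax.numpy as jnp

# Hypothetical interleaved KV buffer: 8 combined heads (K even, V odd), head_dim 4.
kv = jnp.arange(8 * 4, dtype=jnp.float32).reshape(8, 4)
kv_head_idx, step = 1, 4  # step plays the role of num_combined_kv_heads_per_blk

# Old pattern: two strided loads with the same stride.
k_old = kv[kv_head_idx * 2 :: step, :]
v_old = kv[kv_head_idx * 2 + 1 :: step, :]

# New pattern: one call returns the (k, v) pair (float32 branch of the kernel).
def strided_load_kv_f32(ref, start, step):
  return ref[start::step, :], ref[start + 1 :: step, :]

k_new, v_new = strided_load_kv_f32(kv, kv_head_idx * 2, step)
assert (k_old == k_new).all() and (v_old == v_new).all()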

File tree

1 file changed: +9 -12 lines changed


jax/experimental/pallas/ops/tpu/ragged_paged_attention.py

@@ -129,7 +129,7 @@ def ref_ragged_paged_attention(
   return jnp.concatenate(outputs, axis=0)
 
 
-# Expect to run these checkes during runtime.
+# Expect to run these checks during runtime.
 def validate_dynamic_inputs(
     q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
     kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
@@ -283,19 +283,19 @@ def create_kv_async_copy_descriptors(
 # 2. Support arbitrary strided load/store for any last dimension.
 def strided_load_kv(ref, start, step):
   if ref.dtype == jnp.float32:
-    return ref[start::step, :]
+    return ref[start::step, :], ref[start + 1 :: step, :]
   packing = get_dtype_packing(ref.dtype)
   assert ref.dtype == jnp.bfloat16
   assert step % packing == 0
   b_start = start // packing
-  b_offset = start % packing
   b_step = step // packing
-  b_ref = ref.bitcast(jnp.int32)
+  b_ref = ref.bitcast(jnp.uint32)
   b = b_ref[b_start::b_step, :]
-  bw = 32 // packing
-  b = jnp.right_shift(b, bw * b_offset)
-  b = jnp.left_shift(b, bw * (packing - 1))
-  return pltpu.bitcast(b, jnp.float32).astype(jnp.bfloat16)
+  bk = b << 16
+  bv = b & jnp.uint32(0xffff0000)
+  k = pltpu.bitcast(bk, jnp.float32).astype(jnp.bfloat16)
+  v = pltpu.bitcast(bv, jnp.float32).astype(jnp.bfloat16)
+  return k, v
 
 
 def fold_on_2nd_minor(vec):
   assert vec.dtype == jnp.bfloat16 or vec.dtype == jnp.float32
@@ -537,12 +537,9 @@ def prefetch_next_kv_blk():
    q = fold_on_2nd_minor(
        q_ref[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :]
    )
-    k = strided_load_kv(
+    k, v = strided_load_kv(
        kv_ref, kv_head_idx * 2, num_combined_kv_heads_per_blk
    )
-    v = strided_load_kv(
-        kv_ref, kv_head_idx * 2 + 1, num_combined_kv_heads_per_blk
-    )
    flash_attention(
        q,
        k,