
Commit b719ac0

Use an f32 scratch buffer for the output so we only need to transfer the output in the desired dtype back to HBM.

We use f32 as the accumulation dtype inside the kernel. Before writing the result from VMEM to HBM, we convert it to the desired dtype (e.g. bf16), which saves memory bandwidth. Also, a minor change: the sliding-window and logit-soft-cap checks now live in the function that validates the static inputs. PiperOrigin-RevId: 741660728
1 parent 2d63b6e commit b719ac0
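
The pattern this commit adopts, sketched below as a minimal standalone Pallas kernel (the kernel, shapes, and function names here are illustrative assumptions, not the actual ragged paged attention code): do the arithmetic in an f32 VMEM scratch buffer and cast to the output dtype only at the final write, so the transfer back to HBM happens in the smaller dtype (e.g. bf16).

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def _scale_kernel(x_ref, o_ref, acc_ref):
  # Accumulate in the f32 scratch buffer regardless of the input dtype.
  acc_ref[...] = x_ref[...].astype(jnp.float32) * 2.0
  # Cast to the desired output dtype (e.g. bf16) only for the write to HBM.
  o_ref[...] = acc_ref[...].astype(o_ref.dtype)


def scale_with_f32_scratch(x):
  return pl.pallas_call(
      _scale_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),  # output keeps x.dtype
      scratch_shapes=[pltpu.VMEM(x.shape, jnp.float32)],  # f32 accumulator
  )(x)


x = jnp.ones((8, 128), dtype=jnp.bfloat16)
assert scale_with_f32_scratch(x).dtype == jnp.bfloat16
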

File tree: 2 files changed, +35 -27 lines

jax/experimental/pallas/ops/tpu/ragged_paged_attention.py

Lines changed: 31 additions & 23 deletions
@@ -83,8 +83,8 @@ def ref_ragged_paged_attention(
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
 ):
-  check_inputs_shapes(
-      queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs
+  validate_static_inputs(
+      queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap
   )
   if mask_value is None:
     mask_value = DEFAULT_MASK_VALUE
@@ -130,7 +130,7 @@ def ref_ragged_paged_attention(


 # Expect to run these checks during runtime.
-def validate_inputs_on_runtime(
+def validate_dynamic_inputs(
     q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
     kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
     kv_lens: jax.Array,  # i32[max_num_seqs]
@@ -140,7 +140,7 @@ def validate_inputs_on_runtime(
     sliding_window: int | None = None,
     soft_cap: float | None = None,
 ):
-  check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs)
+  validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap)
   max_num_batched_tokens = q.shape[0]
   page_size = kv_pages.shape[1]
   max_num_seqs, pages_per_seq = page_indices.shape
@@ -165,20 +165,18 @@ def validate_inputs_on_runtime(
       raise ValueError(
           f"{q_len=} must be less or equal to {kv_len=} at sequence {i}."
       )
-  if sliding_window is not None and sliding_window <= 0:
-    raise ValueError(f"{sliding_window=} must be positive.")
-  if soft_cap is not None and soft_cap == 0.0:
-    raise ValueError(f"{soft_cap=} must not be 0.0.")


 # Expect to run these checks during compile time.
-def check_inputs_shapes(
+def validate_static_inputs(
     q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
     kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
     kv_lens: jax.Array,  # i32[max_num_seqs]
     page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
     num_seqs,  # i32[1]
+    sliding_window: int | None = None,
+    soft_cap: float | None = None,
 ):
   _, num_q_heads, head_dim = q.shape
   _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape
@@ -213,6 +211,10 @@ def check_inputs_shapes(
     )
   if num_q_heads % num_kv_heads != 0:
     raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}")
+  if sliding_window is not None and sliding_window <= 0:
+    raise ValueError(f"{sliding_window=} must be positive.")
+  if soft_cap is not None and soft_cap == 0.0:
+    raise ValueError(f"{soft_cap=} must not be 0.0.")


 def ragged_paged_attention_kernel(
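
For reference, a standalone sketch of the sliding-window and soft-cap checks that the hunk above moves into the static validator (the helper below is illustrative only, not part of the library):

def _check_sliding_window_and_soft_cap(sliding_window, soft_cap):
  # Mirrors the two checks added to validate_static_inputs above.
  if sliding_window is not None and sliding_window <= 0:
    raise ValueError(f"{sliding_window=} must be positive.")
  if soft_cap is not None and soft_cap == 0.0:
    raise ValueError(f"{soft_cap=} must not be 0.0.")


_check_sliding_window_and_soft_cap(sliding_window=128, soft_cap=30.0)  # OK
_check_sliding_window_and_soft_cap(sliding_window=None, soft_cap=None)  # OK (both disabled)
# _check_sliding_window_and_soft_cap(sliding_window=0, soft_cap=None)   # would raise ValueError
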
@@ -233,6 +235,7 @@ def ragged_paged_attention_kernel(
     sems,  # [2, 2]
     l_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
     m_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
+    acc_ref,  # [num_q_per_blk, num_q_heads_per_blk, head_dim]
     *,
     sm_scale: float,
     sliding_window: int | None = None,
@@ -357,7 +360,7 @@ def flash_attention(
       v,  # [num_kv_per_blk, head_dim]
       head_l_ref,  # [num_q_per_blk * num_q_heads_per_kv_head, 128]
       head_m_ref,  # [num_q_per_blk * num_q_heads_per_kv_head, 128]
-      head_o_ref,  # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
+      head_acc_ref,  # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
       *,
       kv_blk_idx,
   ):
@@ -378,7 +381,7 @@ def flash_attention(
        num_q_per_blk * num_q_heads_per_kv_head,
        128,
    )
-    assert head_o_ref.shape == (
+    assert head_acc_ref.shape == (
        num_q_per_blk,
        num_q_heads_per_kv_head,
        head_dim,
@@ -414,8 +417,8 @@ def init_scratch_ref():
          num_q_heads_per_kv_head,
      )
      masked_store(
-         head_o_ref,
-         jnp.zeros_like(head_o_ref),
+         head_acc_ref,
+         jnp.zeros_like(head_acc_ref),
          store_start,
          store_end,
      )
@@ -481,17 +484,17 @@ def broadcast_to_shape(arr, shape):
          [arr for _ in range(shape[1] // arr.shape[1])], axis=1
      )

-    o_curr = head_o_ref[...].reshape(-1, head_dim)
+    o_curr = head_acc_ref[...].reshape(-1, head_dim)
     l_alpha = broadcast_to_shape(l_alpha, qkv.shape)
     beta = broadcast_to_shape(beta, qkv.shape)
     l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape)
     out = lax.div(
         l_alpha * o_curr + beta * qkv,
         l_next_safe,
-    ).astype(head_o_ref.dtype)
+    )
     masked_store(
-        head_o_ref,
-        out.reshape(head_o_ref.shape),
+        head_acc_ref,
+        out.reshape(head_acc_ref.shape),
         store_start,
         store_end,
     )
@@ -544,7 +547,7 @@ def prefetch_next_kv_blk():
         v,
         l_ref.at[kv_head_idx],
         m_ref.at[kv_head_idx],
-        o_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :],
+        acc_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :],
         kv_blk_idx=kv_blk_idx,
     )
     return kv_blk_idx + 1, next_buf_idx
@@ -566,6 +569,7 @@ def prefetch_next_kv_blk():
   # Reset seq_idx for next kv_heads_blk if run out of seqs!
   seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0)
   seq_buf_idx_ref[1] = buf_idx
+  o_ref[...] = acc_ref[...].astype(q_ref.dtype)


 def cdiv(a, b):
@@ -662,6 +666,7 @@ def ragged_paged_attention(
     num_seqs: the dynamic number of sequences.
     sm_scale: the softmax scale which will be applied to the Q@K^T.
     sliding_window: the sliding window size for the attention.
+    soft_cap: the logit soft cap for the attention.
     mask_value: mask value for causal mask.
     num_kv_pages_per_block: number of kv pages to be processed in one flash
       attention block in the pallas kernel.
@@ -672,7 +677,7 @@ def ragged_paged_attention(
   Returns:
     The output of the attention.
   """
-  check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs)
+  validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap)
   if mask_value is None:
     mask_value = DEFAULT_MASK_VALUE
   _, num_q_heads, head_dim = q.shape
@@ -710,6 +715,10 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
       (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128),
       jnp.float32,
   )
+  acc_scratch = pltpu.VMEM(
+      (num_q_per_blk, num_q_heads_per_blk, head_dim),
+      jnp.float32,
+  )
   double_buf_scratch = pltpu.VMEM(
       (
           2,  # For double buffering during DMA copies.
@@ -725,6 +734,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
       pltpu.SemaphoreType.DMA((2,)),  # Semaphores for double buffers.
       lm_scratch,  # l_ref
       lm_scratch,  # m_ref
+      acc_scratch,
   ]
   scalar_prefetches = (
       kv_lens,
@@ -755,10 +765,8 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
          ),
          vmem_limit_bytes=vmem_limit_bytes,
      ),
-      out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32),
+      out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
      name="ragged_paged_attention_kernel",
  )

-  # TODO(jevinjiang): Use f32 acc scratch for output! So we only need
-  # to transfer output with desired dtype back to HBM.
-  return kernel(*scalar_prefetches, q, kv_pages).astype(q.dtype)
+  return kernel(*scalar_prefetches, q, kv_pages)

tests/pallas/tpu_ragged_paged_attention_test.py

Lines changed: 4 additions & 4 deletions
@@ -21,7 +21,7 @@
 from jax.experimental.pallas.ops.tpu.ragged_paged_attention import (
     ragged_paged_attention,
     ref_ragged_paged_attention,
-    validate_inputs_on_runtime,
+    validate_dynamic_inputs,
 )
 import jax.numpy as jnp

@@ -91,15 +91,15 @@ def _test_ragged_paged_attention(

     num_seqs = jnp.array([len(seq_lens)], dtype=jnp.int32)

-    validate_inputs_on_runtime(
+    validate_dynamic_inputs(
         q,
         kv_pages,
         kv_lens,
         page_indices,
         cu_q_lens,
         num_seqs,
-        sliding_window=sliding_window,
-        soft_cap=soft_cap,
+        sliding_window,
+        soft_cap,
     )

     actual_num_q_tokens = cu_q_lens[num_seqs[0]]
