Skip to content

Commit eff612a

Browse files
bythew3i authored and Google-ML-Automation committed
Fix the assumption that pages_per_seq is already a multiple of num_kv_pages_per_blk.
PiperOrigin-RevId: 735851301
1 parent 0db14aa commit eff612a

File tree

2 files changed

+18
-14
lines changed

2 files changed

+18
-14
lines changed

jax/experimental/pallas/ops/tpu/ragged_paged_attention.py

Lines changed: 16 additions & 8 deletions
```diff
@@ -43,14 +43,22 @@ def __init__(
   ):
     self._vmem_buf = vmem_buf
     seq_id, kv_pages_start = offset
-    self._async_copies = [
-        pltpu.make_async_copy(
-            pages_hbm_ref.at[page_indices_ref[seq_id, kv_pages_start + i]],
-            vmem_buf.at[i],
-            sem,
-        )
-        for i in range(vmem_buf.shape[0])
-    ]
+    pages_per_seq = page_indices_ref.shape[1]
+    self._async_copies = []
+    # TODO(jevinjiang): Only fetch dynamic shape in need! This will insert
+    # a bunch of if-ops. Check the performance when we have benchmarking setup.
+    for i in range(vmem_buf.shape[0]):
+      page_idx = kv_pages_start + i
+      page_idx = jax.lax.select(
+          page_idx < pages_per_seq, page_idx, pages_per_seq - 1
+      )
+      self._async_copies.append(
+          pltpu.make_async_copy(
+              pages_hbm_ref.at[page_indices_ref[seq_id, page_idx]],
+              vmem_buf.at[i],
+              sem,
+          )
+      )

   def start(self):
     """Starts the async copies."""
```

tests/pallas/tpu_ragged_paged_attention_test.py

Lines changed: 2 additions & 6 deletions
```diff
@@ -64,10 +64,6 @@ def _test_ragged_paged_attention(
     max_num_seq = max(len(seq_lens), max_num_seq)
     max_kv_len = max(kv_lens)
     pages_per_seq = ceil_div(max_kv_len, page_size)
-    pages_per_seq = (
-        ceil_div(pages_per_seq, num_kv_pages_per_block)
-        * num_kv_pages_per_block
-    )
     num_q_heads, num_kv_heads = num_heads

     cu_q_lens = jnp.array(cu_q_lens, dtype=jnp.int32)
@@ -130,8 +126,8 @@ def _test_ragged_paged_attention(
         num_seqs=num_seqs,
     )
     tols = {
-        "float32": 1e-1,
-        "bfloat16": 2e-1,
+        "float32": 0.15,
+        "bfloat16": 0.2,
     }
     tol = tols[jnp.dtype(dtype).name]
     self.assertAllClose(output, expected, atol=tol, rtol=tol)
```

0 commit comments

Comments (0)