Commit 84ec21e

Add sliding window support to the ragged paged attention.
PiperOrigin-RevId: 738457532
1 parent 918192f commit 84ec21e

2 files changed (+58, -5 lines)

jax/experimental/pallas/ops/tpu/ragged_paged_attention.py

Lines changed: 19 additions & 4 deletions
@@ -19,7 +19,6 @@
 specifications. It supports mixed prefill and decoding, enhancing throughput
 during inference.
 """
-
 import functools
 import jax
 from jax import lax
@@ -81,6 +80,7 @@ def ref_ragged_paged_attention(
     num_seqs: jax.Array,  # i32[1],
     *,
     sm_scale: float = 1.0,
+    sliding_window: int | None = None,
     mask_value: float = DEFAULT_MASK_VALUE,
 ):
   _, _, num_kv_heads, head_dim = k_pages.shape
@@ -105,7 +105,10 @@ def ref_ragged_paged_attention(
         jnp.int32, attn.shape, 1
     )
     kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
-    attn += jnp.where(q_span < kv_span, mask_value, 0.0)
+    mask = q_span < kv_span
+    if sliding_window is not None:
+      mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
+    attn += jnp.where(mask, mask_value, 0.0)
     attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
     out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
     outputs.append(out)
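
The new reference masking rule, in words: a key is masked out either when it lies in the query's future (causal) or when it is at least sliding_window positions in the past, so each query attends to at most the last sliding_window keys. Below is a minimal standalone sketch of that rule, assuming q_len == kv_len and a single KV head; the dummy logits, window size, and mask constant are illustrative and not taken from the file.

import jax
import jax.numpy as jnp

attn = jnp.zeros((1, 5, 5))  # [num_heads, q_len, kv_len] dummy attention logits
sliding_window = 2
mask_value = -0.7 * float(jnp.finfo(jnp.float32).max)  # stand-in for DEFAULT_MASK_VALUE

q_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 1)
kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
mask = q_span < kv_span  # causal: queries never see future keys
mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)  # drop keys outside the window
attn = attn + jnp.where(mask, mask_value, 0.0)
probs = jax.nn.softmax(attn, axis=-1)
print(jnp.count_nonzero(probs > 1e-6, axis=-1))  # [[1 2 2 2 2]]: at most 2 keys per query

Note the inequality is q_span - sliding_window >= kv_span, so a query at position i still attends to positions i - sliding_window + 1 through i inclusive.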
@@ -122,6 +125,7 @@ def validate_inputs_on_runtime(
     page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
     num_seqs,  # i32[1]
+    sliding_window: int | None = None,
 ):
   check_inputs_shapes(
       q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs
@@ -150,6 +154,8 @@ def validate_inputs_on_runtime(
       raise ValueError(
           f"{q_len=} must be less or equal to {kv_len=} at sequence {i}."
       )
+  if sliding_window is not None and sliding_window <= 0:
+    raise ValueError(f"{sliding_window=} must be positive.")


 # Expect to run these checks during compile time.
@@ -221,7 +227,8 @@ def ragged_paged_attention_kernel(
     m_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
     *,
     sm_scale: float,
-    mask_value: float,
+    sliding_window: int | None = None,
+    mask_value: float = DEFAULT_MASK_VALUE,
 ):
   num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape
   num_seqs = num_seqs_ref[0]
@@ -373,7 +380,7 @@ def flash_attention(
     def masked_store(ref, val, start, end, group=1):
       iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group
       mask = jnp.logical_and(iota >= start, iota < end)
-      pl.store(ref, tuple(slice(None) for _ in ref.shape), val, mask=mask)
+      pl.store(ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask)

    qk = (
        jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32)
@@ -422,6 +429,9 @@ def init_scratch_ref():
           1,
       )
       causal_mask = row_ids < col_ids
+      if sliding_window is not None:
+        causal_mask = jnp.logical_or(causal_mask,
+                                     row_ids - sliding_window >= col_ids)
       qk += jnp.where(causal_mask, mask_value, 0.0)
       m_curr = jnp.max(qk, axis=1, keepdims=True)
       s_curr = jnp.exp(qk - m_curr)
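
The kernel applies the same inequality block by block, with row_ids and col_ids holding absolute positions, so no extra state has to be carried across KV blocks. A small sketch of that property, assuming q_len == kv_len and a toy block size; all names below are local to the sketch.

import jax.numpy as jnp
from jax import lax

q_len, kv_len, window, blk = 8, 8, 3, 4

# Combined causal + sliding-window mask over all positions; True means "masked out".
row = lax.broadcasted_iota(jnp.int32, (q_len, kv_len), 0)
col = lax.broadcasted_iota(jnp.int32, (q_len, kv_len), 1)
full_mask = jnp.logical_or(row < col, row - window >= col)

# The same inequality evaluated per KV block with absolute column indices.
for kv_start in range(0, kv_len, blk):
  row_blk = lax.broadcasted_iota(jnp.int32, (q_len, blk), 0)
  col_blk = kv_start + lax.broadcasted_iota(jnp.int32, (q_len, blk), 1)
  blk_mask = jnp.logical_or(row_blk < col_blk, row_blk - window >= col_blk)
  assert bool((blk_mask == full_mask[:, kv_start:kv_start + blk]).all())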
@@ -601,6 +611,7 @@ def can_be_xla_fully_tiled(x, packing):
         "num_kv_pages_per_block",
         "num_queries_per_block",
         "vmem_limit_bytes",
+        "sliding_window",
     ],
 )
 def ragged_paged_attention(
@@ -614,6 +625,7 @@ def ragged_paged_attention(
     num_seqs: jax.Array,  # i32[1]
     *,
     sm_scale: float = 1.0,
+    sliding_window: int | None = None,
     mask_value: float = DEFAULT_MASK_VALUE,
     num_kv_pages_per_block: int = 16,
     num_queries_per_block: int = 128,
@@ -632,6 +644,7 @@ def ragged_paged_attention(
       kv_lens, only the first num_seqs+1 values are valid.
     num_seqs: the dynamic number of sequences.
     sm_scale: the softmax scale which will be applied to the Q@K^T.
+    sliding_window: the sliding window size for the attention.
     mask_value: mask value for causal mask.
     num_kv_pages_per_block: number of kv pages to be processed in one flash
       attention block in the pallas kernel.
@@ -705,6 +718,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
       functools.partial(
          ragged_paged_attention_kernel,
          sm_scale=sm_scale,
+         sliding_window=sliding_window,
          mask_value=mask_value,
      ),
      grid_spec=pltpu.PrefetchScalarGridSpec(
@@ -724,6 +738,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
       out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32),
       name="ragged_paged_attention_kernel",
   )
+
   # TODO(jevinjiang): Use f32 acc scratch for output! So we only need
   # to transfer output with desired dtype back to HBM.
   return kernel(*scalar_prefetches, q, k_pages, v_pages).astype(q.dtype)
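
For callers, the only API change is the new sliding_window keyword argument. A hedged usage sketch follows (not part of the commit): the import path mirrors the file location above, the shapes and lengths are made-up values consistent with the argument comments in the diff, the (num_pages, page_size) layout of the leading k_pages/v_pages dimensions is an assumption, and actually running it requires a TPU with Pallas support.

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.ragged_paged_attention import (
    ragged_paged_attention,
)

num_q_heads, num_kv_heads, head_dim = 32, 8, 128
page_size, num_pages, pages_per_seq, max_num_seqs = 16, 1000, 64, 8

q = jnp.zeros((512, num_q_heads, head_dim), jnp.bfloat16)           # packed query tokens
k_pages = jnp.zeros((num_pages, page_size, num_kv_heads, head_dim), jnp.bfloat16)
v_pages = jnp.zeros_like(k_pages)
kv_lens = jnp.array([328, 180, 255] + [0] * 5, jnp.int32)           # i32[max_num_seqs]
page_indices = jnp.zeros((max_num_seqs, pages_per_seq), jnp.int32)  # i32[max_num_seqs, pages_per_seq]
cu_q_lens = jnp.array([0, 192, 320, 384] + [384] * 5, jnp.int32)    # i32[max_num_seqs + 1]
num_seqs = jnp.array([3], jnp.int32)                                # i32[1]

out = ragged_paged_attention(
    q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs,
    sm_scale=head_dim**-0.5,
    sliding_window=128,  # None (the default) keeps full causal attention
)[: cu_q_lens[num_seqs[0]]]  # trim padded query rows, as the test below does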

tests/pallas/tpu_ragged_paged_attention_test.py

Lines changed: 39 additions & 1 deletion
@@ -13,6 +13,7 @@
 # limitations under the License.

 import random
+
 from absl.testing import absltest
 from absl.testing import parameterized
 import jax
@@ -50,6 +51,7 @@ def _test_ragged_paged_attention(
      vmem_limit_bytes=32 * 1024 * 1024,
      max_num_batched_tokens=512,
      max_num_seq=8,
+      sliding_window: int | None = None,
  ):
    if not jtu.is_device_tpu_at_least(version=4):
      self.skipTest("Expect TPUv4+")
@@ -101,8 +103,10 @@ def _test_ragged_paged_attention(
        page_indices,
        cu_q_lens,
        num_seqs,
+        sliding_window=sliding_window,
    )

+    actual_num_q_tokens = cu_q_lens[num_seqs[0]]
    output = ragged_paged_attention(
        q,
        k_pages,
@@ -114,7 +118,8 @@ def _test_ragged_paged_attention(
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
        vmem_limit_bytes=vmem_limit_bytes,
-    )[: cu_q_lens[num_seqs[0]]]
+        sliding_window=sliding_window,
+    )[: actual_num_q_tokens]

    expected = ref_ragged_paged_attention(
        q,
@@ -124,6 +129,7 @@ def _test_ragged_paged_attention(
        page_indices,
        cu_q_lens,
        num_seqs=num_seqs,
+        sliding_window=sliding_window,
    )
    tols = {
        "float32": 0.15,
@@ -266,6 +272,7 @@ def test_ragged_paged_attention_mixed(self, dtype):
      dtype=[jnp.float32, jnp.bfloat16],
      num_kv_pages_per_block=[4, 8],
      num_queries_per_block=[32, 64],
+      sliding_window=[None, 5, 128],
  )
  def test_ragged_paged_attention_complex(
      self,
@@ -274,6 +281,7 @@ def test_ragged_paged_attention_complex(
      dtype,
      num_kv_pages_per_block,
      num_queries_per_block,
+      sliding_window: int | None,
  ):
    seq_lens = []
    for _ in range(num_seqs):
@@ -294,8 +302,38 @@ def test_ragged_paged_attention_complex(
        num_pages,
        num_kv_pages_per_block=num_kv_pages_per_block,
        num_queries_per_block=num_queries_per_block,
+        sliding_window=sliding_window,
    )

+  def test_ragged_paged_attention_sliding_window_should_be_positive(self):
+    dtype = jnp.float32
+    seq_lens = [(192, 328), (128, 180), (64, 255)]
+    num_heads = (32, 8)
+    head_dim = 128
+    page_size = 16
+    num_pages = 1000
+
+    with self.assertRaisesRegex(ValueError, "must be positive"):
+      self._test_ragged_paged_attention(
+          seq_lens,
+          num_heads,
+          head_dim,
+          page_size,
+          dtype,
+          num_pages,
+          sliding_window=0,
+      )
+
+    with self.assertRaisesRegex(ValueError, "must be positive"):
+      self._test_ragged_paged_attention(
+          seq_lens,
+          num_heads,
+          head_dim,
+          page_size,
+          dtype,
+          num_pages,
+          sliding_window=-1,
+      )

 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())
