specifications. It supports mixed prefill and decoding, enhancing throughput
during inference.
"""
-
import functools
import jax
from jax import lax
@@ -81,6 +80,7 @@ def ref_ragged_paged_attention(
    num_seqs: jax.Array,  # i32[1]
    *,
    sm_scale: float = 1.0,
+    sliding_window: int | None = None,
    mask_value: float = DEFAULT_MASK_VALUE,
):
  _, _, num_kv_heads, head_dim = k_pages.shape
@@ -105,7 +105,10 @@ def ref_ragged_paged_attention(
        jnp.int32, attn.shape, 1
    )
    kv_span = jax.lax.broadcasted_iota(jnp.int32, attn.shape, 2)
-    attn += jnp.where(q_span < kv_span, mask_value, 0.0)
+    mask = q_span < kv_span
+    if sliding_window is not None:
+      mask = jnp.logical_or(mask, q_span - sliding_window >= kv_span)
+    attn += jnp.where(mask, mask_value, 0.0)
    attn = jax.nn.softmax(attn, axis=-1).astype(v.dtype)
    out = jnp.einsum("hqk,khd->qhd", attn, v).astype(queries.dtype)
    outputs.append(out)
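
As an aside (not part of the diff): a minimal standalone sketch of the combined causal plus sliding-window mask that the reference implementation above builds. It uses toy sizes and a 2-D (query, key) mask for a single head, whereas the reference builds the 3-D (head, query, key) variant; True marks entries that get mask_value added.

import jax
import jax.numpy as jnp

q_len, kv_len, window = 4, 6, 3  # toy sizes for illustration
# Query row i corresponds to absolute kv position (kv_len - q_len) + i.
q_span = (kv_len - q_len) + jax.lax.broadcasted_iota(jnp.int32, (q_len, kv_len), 0)
kv_span = jax.lax.broadcasted_iota(jnp.int32, (q_len, kv_len), 1)
mask = q_span < kv_span  # causal: mask out future kv positions
mask = jnp.logical_or(mask, q_span - window >= kv_span)  # mask out kv older than the window
# Each query now attends to at most `window` kv positions: (q - window, q].
print(mask.astype(jnp.int32))
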
@@ -122,6 +125,7 @@ def validate_inputs_on_runtime(
    page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
    cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
    num_seqs,  # i32[1]
+    sliding_window: int | None = None,
):
  check_inputs_shapes(
      q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs
@@ -150,6 +154,8 @@ def validate_inputs_on_runtime(
      raise ValueError(
          f"{q_len=} must be less or equal to {kv_len=} at sequence {i}."
      )
+  if sliding_window is not None and sliding_window <= 0:
+    raise ValueError(f"{sliding_window=} must be positive.")


# Expect to run these checks during compile time.
@@ -221,7 +227,8 @@ def ragged_paged_attention_kernel(
    m_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
    *,
    sm_scale: float,
-    mask_value: float,
+    sliding_window: int | None = None,
+    mask_value: float = DEFAULT_MASK_VALUE,
):
  num_q_per_blk, num_q_heads_per_blk, head_dim = q_ref.shape
  num_seqs = num_seqs_ref[0]
@@ -373,7 +380,7 @@ def flash_attention(
    def masked_store(ref, val, start, end, group=1):
      iota = lax.broadcasted_iota(jnp.int32, ref.shape, 0) // group
      mask = jnp.logical_and(iota >= start, iota < end)
-      pl.store(ref, tuple(slice(None) for _ in ref.shape), val, mask=mask)
+      pl.store(ref, idx=tuple(slice(None) for _ in ref.shape), val=val, mask=mask)

    qk = (
        jnp.einsum("nd,md->nm", q, k, preferred_element_type=jnp.float32)
@@ -422,6 +429,9 @@ def init_scratch_ref():
        1,
    )
    causal_mask = row_ids < col_ids
+    if sliding_window is not None:
+      causal_mask = jnp.logical_or(causal_mask,
+                                   row_ids - sliding_window >= col_ids)
    qk += jnp.where(causal_mask, mask_value, 0.0)
    m_curr = jnp.max(qk, axis=1, keepdims=True)
    s_curr = jnp.exp(qk - m_curr)
@@ -601,6 +611,7 @@ def can_be_xla_fully_tiled(x, packing):
601611 "num_kv_pages_per_block" ,
602612 "num_queries_per_block" ,
603613 "vmem_limit_bytes" ,
614+ "sliding_window" ,
604615 ],
605616)
606617def ragged_paged_attention (
@@ -614,6 +625,7 @@ def ragged_paged_attention(
    num_seqs: jax.Array,  # i32[1]
    *,
    sm_scale: float = 1.0,
+    sliding_window: int | None = None,
    mask_value: float = DEFAULT_MASK_VALUE,
    num_kv_pages_per_block: int = 16,
    num_queries_per_block: int = 128,
@@ -632,6 +644,7 @@ def ragged_paged_attention(
      kv_lens, only the first num_seqs+1 values are valid.
    num_seqs: the dynamic number of sequences.
    sm_scale: the softmax scale which will be applied to the Q@K^T.
+    sliding_window: the sliding window size for the attention.
    mask_value: mask value for causal mask.
    num_kv_pages_per_block: number of kv pages to be processed in one flash
      attention block in the pallas kernel.
@@ -705,6 +718,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
      functools.partial(
          ragged_paged_attention_kernel,
          sm_scale=sm_scale,
+          sliding_window=sliding_window,
          mask_value=mask_value,
      ),
      grid_spec=pltpu.PrefetchScalarGridSpec(
@@ -724,6 +738,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
      out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32),
      name="ragged_paged_attention_kernel",
  )
+
  # TODO(jevinjiang): Use f32 acc scratch for output! So we only need
  # to transfer output with desired dtype back to HBM.
  return kernel(*scalar_prefetches, q, k_pages, v_pages).astype(q.dtype)
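
For completeness, a hedged call sketch showing the new flag end to end; every shape and value below is made up for illustration, and the real kernel's padding and tiling constraints on page_size, head_dim, and block sizes are ignored here.

import jax.numpy as jnp

max_num_seqs, pages_per_seq, page_size = 2, 4, 16
num_q_heads, num_kv_heads, head_dim = 8, 2, 128
num_pages = max_num_seqs * pages_per_seq

q = jnp.zeros((32, num_q_heads, head_dim), jnp.bfloat16)  # 32 query tokens in total
k_pages = jnp.zeros((num_pages, page_size, num_kv_heads, head_dim), jnp.bfloat16)
v_pages = jnp.zeros_like(k_pages)
kv_lens = jnp.array([24, 8], jnp.int32)  # per-sequence kv lengths
page_indices = jnp.arange(num_pages, dtype=jnp.int32).reshape(max_num_seqs, pages_per_seq)
cu_q_lens = jnp.array([0, 24, 32], jnp.int32)  # cumulative query lengths, num_seqs + 1 entries
num_seqs = jnp.array([2], jnp.int32)

out = ragged_paged_attention(
    q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs,
    sm_scale=head_dim**-0.5,
    sliding_window=16,  # each query attends to at most the last 16 kv tokens
)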