@@ -83,8 +83,8 @@ def ref_ragged_paged_attention(
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
 ):
-  check_inputs_shapes(
-      queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs
+  validate_static_inputs(
+      queries, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap
   )
   if mask_value is None:
     mask_value = DEFAULT_MASK_VALUE
@@ -130,7 +130,7 @@ def ref_ragged_paged_attention(


 # Expect to run these checkes during runtime.
-def validate_inputs_on_runtime(
+def validate_dynamic_inputs(
     q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
     kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
     kv_lens: jax.Array,  # i32[max_num_seqs]
@@ -140,7 +140,7 @@ def validate_inputs_on_runtime(
     sliding_window: int | None = None,
     soft_cap: float | None = None,
 ):
-  check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs)
+  validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap)
   max_num_batched_tokens = q.shape[0]
   page_size = kv_pages.shape[1]
   max_num_seqs, pages_per_seq = page_indices.shape
@@ -165,20 +165,18 @@ def validate_inputs_on_runtime(
       raise ValueError(
           f"{q_len=} must be less or equal to {kv_len=} at sequence {i}."
       )
-  if sliding_window is not None and sliding_window <= 0:
-    raise ValueError(f"{sliding_window=} must be positive.")
-  if soft_cap is not None and soft_cap == 0.0:
-    raise ValueError(f"{soft_cap=} must not be 0.0.")


 # Expect to run these checks during compile time.
-def check_inputs_shapes(
+def validate_static_inputs(
     q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
     kv_pages: jax.Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
     kv_lens: jax.Array,  # i32[max_num_seqs]
     page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
     num_seqs,  # i32[1]
+    sliding_window: int | None = None,
+    soft_cap: float | None = None,
 ):
   _, num_q_heads, head_dim = q.shape
   _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape
@@ -213,6 +211,10 @@ def check_inputs_shapes(
       )
   if num_q_heads % num_kv_heads != 0:
     raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}")
+  if sliding_window is not None and sliding_window <= 0:
+    raise ValueError(f"{sliding_window=} must be positive.")
+  if soft_cap is not None and soft_cap == 0.0:
+    raise ValueError(f"{soft_cap=} must not be 0.0.")


 def ragged_paged_attention_kernel(
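For context: the soft_cap == 0.0 guard exists because logit soft-capping conventionally divides the logits by soft_cap before a tanh, so a zero cap would divide by zero. A minimal sketch of that transform, assuming the usual tanh-style capping (the helper name is illustrative, not code from this file):

import jax.numpy as jnp

def soft_cap_logits(logits: jnp.ndarray, soft_cap: float | None) -> jnp.ndarray:
  # Hypothetical helper: squashes logits into (-soft_cap, soft_cap).
  # soft_cap == 0.0 would divide by zero, hence the validation above.
  if soft_cap is None:
    return logits
  return soft_cap * jnp.tanh(logits / soft_cap)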
@@ -233,6 +235,7 @@ def ragged_paged_attention_kernel(
     sems,  # [2, 2]
     l_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
     m_ref,  # [num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128]
+    acc_ref,  # [num_q_per_blk, num_q_heads_per_blk, head_dim]
     *,
     sm_scale: float,
     sliding_window: int | None = None,
@@ -357,7 +360,7 @@ def flash_attention(
       v,  # [num_kv_per_blk, head_dim]
       head_l_ref,  # [num_q_per_blk * num_q_heads_per_kv_head, 128]
       head_m_ref,  # [num_q_per_blk * num_q_heads_per_kv_head, 128]
-      head_o_ref,  # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
+      head_acc_ref,  # [num_q_per_blk, num_q_heads_per_kv_head, head_dim]
       *,
       kv_blk_idx,
   ):
@@ -378,7 +381,7 @@ def flash_attention(
         num_q_per_blk * num_q_heads_per_kv_head,
         128,
     )
-    assert head_o_ref.shape == (
+    assert head_acc_ref.shape == (
         num_q_per_blk,
         num_q_heads_per_kv_head,
         head_dim,
@@ -414,8 +417,8 @@ def init_scratch_ref():
           num_q_heads_per_kv_head,
       )
       masked_store(
-          head_o_ref,
-          jnp.zeros_like(head_o_ref),
+          head_acc_ref,
+          jnp.zeros_like(head_acc_ref),
           store_start,
           store_end,
       )
@@ -481,17 +484,17 @@ def broadcast_to_shape(arr, shape):
           [arr for _ in range(shape[1] // arr.shape[1])], axis=1
       )

-    o_curr = head_o_ref[...].reshape(-1, head_dim)
+    o_curr = head_acc_ref[...].reshape(-1, head_dim)
     l_alpha = broadcast_to_shape(l_alpha, qkv.shape)
     beta = broadcast_to_shape(beta, qkv.shape)
     l_next_safe = broadcast_to_shape(l_next_safe, qkv.shape)
     out = lax.div(
         l_alpha * o_curr + beta * qkv,
         l_next_safe,
-    ).astype(head_o_ref.dtype)
+    )
     masked_store(
-        head_o_ref,
-        out.reshape(head_o_ref.shape),
+        head_acc_ref,
+        out.reshape(head_acc_ref.shape),
         store_start,
         store_end,
     )
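For context: dropping the per-block .astype(head_o_ref.dtype) means the running output is only rounded to the model dtype once, when the kernel finally writes o_ref. A self-contained sketch of the same online-softmax idea with a float32 accumulator, simplified to the unnormalized textbook form rather than the kernel's per-block normalization (function name and the toy two-block shapes below are illustrative assumptions):

import jax.numpy as jnp

def online_softmax_attention(q, k_blocks, v_blocks):
  # Running softmax statistics and the output accumulator stay in float32,
  # mirroring m_ref/l_ref/acc_ref; only the final result is cast back.
  m = jnp.full((q.shape[0], 1), -jnp.inf, jnp.float32)    # running row max
  l = jnp.zeros((q.shape[0], 1), jnp.float32)             # running denominator
  acc = jnp.zeros((q.shape[0], q.shape[1]), jnp.float32)  # float32 accumulator
  for k, v in zip(k_blocks, v_blocks):
    s = q.astype(jnp.float32) @ k.astype(jnp.float32).T   # one block of logits
    m_new = jnp.maximum(m, s.max(axis=-1, keepdims=True))
    alpha = jnp.exp(m - m_new)                            # rescale old stats
    p = jnp.exp(s - m_new)
    l = alpha * l + p.sum(axis=-1, keepdims=True)
    acc = alpha * acc + p @ v.astype(jnp.float32)
    m = m_new
  return (acc / l).astype(q.dtype)                        # single down-cast

q = jnp.ones((4, 128), jnp.bfloat16)
ks = vs = [jnp.ones((16, 128), jnp.bfloat16)] * 2
out = online_softmax_attention(q, ks, vs)  # bfloat16 out, float32 math inside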
@@ -544,7 +547,7 @@ def prefetch_next_kv_blk():
         v,
         l_ref.at[kv_head_idx],
         m_ref.at[kv_head_idx],
-        o_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :],
+        acc_ref.at[:, q_head_idx : q_head_idx + num_q_heads_per_kv_head, :],
         kv_blk_idx=kv_blk_idx,
     )
     return kv_blk_idx + 1, next_buf_idx
@@ -566,6 +569,7 @@ def prefetch_next_kv_blk():
   # Reset seq_idx for next kv_heads_blk if run out of seqs!
   seq_buf_idx_ref[0] = lax.select(seq_idx < num_seqs, seq_idx, 0)
   seq_buf_idx_ref[1] = buf_idx
+  o_ref[...] = acc_ref[...].astype(q_ref.dtype)


 def cdiv(a, b):
@@ -662,6 +666,7 @@ def ragged_paged_attention(
     num_seqs: the dynamic number of sequences.
     sm_scale: the softmax scale which will be applied to the Q@K^T.
     sliding_window: the sliding window size for the attention.
+    soft_cap: the logit soft cap for the attention.
     mask_value: mask value for causal mask.
     num_kv_pages_per_block: number of kv pages to be processed in one flash
       attention block in the pallas kernel.
@@ -672,7 +677,7 @@ def ragged_paged_attention(
   Returns:
     The output of the attention.
   """
-  check_inputs_shapes(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs)
+  validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs, sliding_window, soft_cap)
   if mask_value is None:
     mask_value = DEFAULT_MASK_VALUE
   _, num_q_heads, head_dim = q.shape
@@ -710,6 +715,10 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
       (num_kv_heads_per_blk, num_q_per_blk * num_q_heads_per_kv_head, 128),
       jnp.float32,
   )
+  acc_scratch = pltpu.VMEM(
+      (num_q_per_blk, num_q_heads_per_blk, head_dim),
+      jnp.float32,
+  )
   double_buf_scratch = pltpu.VMEM(
       (
           2,  # For double buffering during DMA copies.
@@ -725,6 +734,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
       pltpu.SemaphoreType.DMA((2,)),  # Semaphores for double buffers.
       lm_scratch,  # l_ref
       lm_scratch,  # m_ref
+      acc_scratch,
   ]
   scalar_prefetches = (
       kv_lens,
@@ -755,10 +765,8 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
           ),
           vmem_limit_bytes=vmem_limit_bytes,
       ),
-      out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=jnp.float32),
+      out_shape=jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),
       name="ragged_paged_attention_kernel",
   )

-  # TODO(jevinjiang): Use f32 acc scratch for output! So we only need
-  # to transfer output with desired dtype back to HBM.
-  return kernel(*scalar_prefetches, q, kv_pages).astype(q.dtype)
+  return kernel(*scalar_prefetches, q, kv_pages)
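For context: with the float32 acc_ref scratch, the kernel now produces its output directly in q.dtype, so the trailing host-side astype is gone. A toy pallas_call sketch of declaring the output dtype via out_shape (a trivial copy kernel, not this attention kernel; interpret=True only serves to run the example off-TPU):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def _copy_kernel(x_ref, o_ref):
  # x_ref and o_ref share the same dtype here, so the store is a plain copy.
  o_ref[...] = x_ref[...]

x = jnp.ones((8, 128), jnp.bfloat16)
y = pl.pallas_call(
    _copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),  # output dtype follows the input
    interpret=True,
)(x)
assert y.dtype == x.dtype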