@@ -112,6 +112,7 @@ def gqa_attention_kv_stage1(
     V_D_HEAD: tl.constexpr,  # Dimension of each key/value head
     SEQ_BLOCK_SIZE: tl.constexpr,  # Block size used for tiling the sequence dim.
     HEAD_BLOCK_SIZE: tl.constexpr,  # pad to 16 if HEAD_RATIO is < 16 to invoke tensor cores.
+    SLIDING_WINDOW: tl.constexpr,
 ):
     """Attention kernel to be used for generate-only batches.
 
@@ -122,7 +123,7 @@ def gqa_attention_kv_stage1(
     Supports non-power-of-2 D_HEAD
 
     Uses flash decoding.
-    KV-cache layout is assumed to be [Batch,Seq, Head, Dim]
+    KV-cache layout is assumed to be [Batch, Seq, Head, Dim]
     1. Fetch the K-cache from 0 to input_pos
     2. Fetch the V-cache from 0 to input_pos
     3. A = Q*K^T [1,D_HEAD] * [1,seq_len,D_HEAD] -> [1, seq_len]
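
A rough NumPy sketch of the per-block work in steps 1-3 above (illustrative only; the function name, shapes, and the stage-2 merge interface are assumptions, not the kernel's actual code):

import numpy as np

def stage1_block(q, k_block, v_block, scale):
    """One flash-decoding stage-1 tile: q is [D_HEAD], k_block/v_block are [BLOCK, D_HEAD]."""
    a = (k_block @ q) * scale          # [BLOCK] attention logits for this KV tile
    m = a.max()                        # tile max for numerical stability
    p = np.exp(a - m)                  # unnormalized probabilities
    o = p @ v_block                    # partial, unnormalized output
    return o, m + np.log(p.sum())      # partial output and logsumexp, merged later in stage 2
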
@@ -145,10 +146,20 @@ def gqa_attention_kv_stage1(
 
     # The number of Q heads that map to each KV head.
     HEAD_RATIO: tl.constexpr = N_HEADS // N_KV_HEADS  # This needs to be a power-of-2
-    if seq_start_pos > kv_position:
-        return
-    seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
-    seq_mask = seq_offsets <= kv_position
+
+    # Apply sliding window constraints
+    if SLIDING_WINDOW > 0:
+        # For sliding window, limit the sequence range
+        sliding_start = tl.maximum(0, kv_position - SLIDING_WINDOW + 1)
+        if seq_start_pos + SEQ_BLOCK_SIZE <= sliding_start or seq_start_pos > kv_position:
+            return
+        seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
+        seq_mask = (seq_offsets <= kv_position) & (seq_offsets >= sliding_start)
+    else:
+        if seq_start_pos > kv_position:
+            return
+        seq_offsets = seq_start_pos + tl.arange(0, SEQ_BLOCK_SIZE)
+        seq_mask = seq_offsets <= kv_position
 
     # Need to pad the head dim to 16 if HEAD_RATIO is < 16 so that tensor cores can be invoked
     #
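
For intuition, the decode-phase block skip and window mask added above can be reproduced outside Triton. This NumPy sketch uses made-up sizes and is not part of the kernel:

import numpy as np

kv_position, SLIDING_WINDOW, SEQ_BLOCK_SIZE = 70, 16, 32
sliding_start = max(0, kv_position - SLIDING_WINDOW + 1)  # first KV position still in the window (55)
for seq_start_pos in range(0, kv_position + 1, SEQ_BLOCK_SIZE):
    if seq_start_pos + SEQ_BLOCK_SIZE <= sliding_start or seq_start_pos > kv_position:
        continue  # the whole tile falls outside the window and is skipped (here: the tile at 0)
    seq_offsets = seq_start_pos + np.arange(SEQ_BLOCK_SIZE)
    seq_mask = (seq_offsets <= kv_position) & (seq_offsets >= sliding_start)
    print(seq_start_pos, seq_offsets[seq_mask])  # keeps positions 55..63 and 64..70
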
@@ -358,6 +369,8 @@ def attention_kv_stage2(
     N_HEADS: tl.constexpr,
     D_HEAD: tl.constexpr,
     SEQ_BLOCK_SIZE: tl.constexpr,  # Nearest power of 2 for num_blocks
+    HAS_SINKS: tl.constexpr,
+    sinks_ptr,
 ):
     # There are batch * N_HEADS programs
     batch_id = tl.program_id(axis=0)
@@ -382,6 +395,11 @@ def attention_kv_stage2(
     sumexp = tl.exp(logsumexp - max_logsumexp)  # [NUM_BLOCKS_POW2]
 
     aggregate_sumexp = tl.sum(sumexp, axis=0)
+    # Add sinks contribution to the softmax denominator
+    if HAS_SINKS:
+        sinks_val = tl.load(sinks_ptr + batch_id * N_HEADS + head_id)
+        sinks_exp = tl.exp(sinks_val - max_logsumexp)
+        aggregate_sumexp += sinks_exp
 
     values_offsets = block_offsets[:, None] * D_HEAD + dhead_offsets[None, :]
     values_mask = block_mask[:, None] * dhead_mask[None, :]
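
The sink behaves like one extra logit that enters the softmax denominator but carries no value row. A hedged NumPy sketch of the stage-2 reduction above (names are illustrative only):

import numpy as np

def stage2_denominator(logsumexp_blocks, sink_logit=None):
    """Combine per-block logsumexp values, optionally folding in an attention-sink logit."""
    m = logsumexp_blocks.max()                  # max_logsumexp
    denom = np.exp(logsumexp_blocks - m).sum()  # aggregate_sumexp over blocks
    if sink_logit is not None:
        denom += np.exp(sink_logit - m)         # sink contributes to the denominator only
    return denom, m
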
@@ -573,6 +591,9 @@ def context_attention_kv_flattened(
     V_D_HEAD: tl.constexpr,  # Dimension of each value head.
     SEQ_BLOCK: tl.constexpr,
     MAX_SEQ_LENGTH: tl.constexpr,
+    SLIDING_WINDOW: tl.constexpr,  # Sliding window size, -1 means no sliding window
+    HAS_SINKS: tl.constexpr,
+    sinks_ptr,
 ):
     """Kernel for context phase.
 
@@ -623,7 +644,15 @@ def context_attention_kv_flattened(
     # input_pos_ptr stores the location at which kv must be written back for the given batch.
     kv_position = tl.load(input_pos_ptr + batch_id)
     num_blocks = (kv_position + seq_len + SEQ_BLOCK - 1) // SEQ_BLOCK
-    for s in range(0, num_blocks + 1, 1):
+    start = 0
+    if SLIDING_WINDOW > 0:
+        # Use the FIRST query in this block for the start calculation: it has the
+        # smallest absolute position, so it determines the earliest KV position
+        # that any query in this block can still attend to.
+        first_q_pos = seq_block_id * SEQ_BLOCK + kv_position  # First query's absolute position
+        earliest_kv_pos = max(0, first_q_pos - SLIDING_WINDOW + 1)
+        start = max(0, earliest_kv_pos // SEQ_BLOCK)
+    for s in range(start, num_blocks + 1):
         kv_seq_offsets = s * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
         kv_seq_mask = kv_seq_offsets < (kv_position + seq_len)
 
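
As a quick sanity check of the start-block arithmetic (hypothetical numbers): with SEQ_BLOCK = 32, SLIDING_WINDOW = 64, kv_position = 100 and seq_block_id = 2, the first query of the block sits at absolute position 2 * 32 + 100 = 164, the earliest KV position it may attend to is 164 - 64 + 1 = 101, so the loop starts at block 101 // 32 = 3 instead of 0.
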
@@ -637,9 +666,17 @@ def context_attention_kv_flattened(
         )
         qk = tl.zeros([SEQ_BLOCK, SEQ_BLOCK], dtype=tl.float32)
         qk += tl.dot(q, k.trans())
-        qk = tl.where(
-            (seq_offsets[:, None] + kv_position) >= kv_seq_offsets[None, :], qk, float("-inf")
-        )
+        # Apply causal mask
+        causal_mask = (seq_offsets[:, None] + kv_position) >= kv_seq_offsets[None, :]
+        # Apply sliding window mask if enabled
+        if SLIDING_WINDOW > 0:
+            sliding_window_mask = kv_seq_offsets[None, :] >= (
+                seq_offsets[:, None] + kv_position - SLIDING_WINDOW + 1
+            )
+            combined_mask = sliding_window_mask & causal_mask
+        else:
+            combined_mask = causal_mask
+        qk = tl.where(combined_mask, qk, float("-inf"))
         qk *= SCALE
         # rowmax
         m_ij = tl.maximum(tl.max(qk, 1), lse_i)
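
The combined mask can be visualized on a toy tile. This NumPy sketch mirrors the two mask expressions above with made-up sizes and is not part of the kernel:

import numpy as np

kv_position, SLIDING_WINDOW, SEQ_BLOCK = 4, 3, 4
seq_offsets = np.arange(SEQ_BLOCK)       # query rows, absolute positions 4..7
kv_seq_offsets = np.arange(SEQ_BLOCK)    # one KV tile covering positions 0..3
causal_mask = (seq_offsets[:, None] + kv_position) >= kv_seq_offsets[None, :]
sliding_window_mask = kv_seq_offsets[None, :] >= (
    seq_offsets[:, None] + kv_position - SLIDING_WINDOW + 1
)
# Each query keeps only keys inside its 3-token backward window and not in its future;
# the remainder of each window lies in the next KV tile.
print(causal_mask & sliding_window_mask)
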
@@ -662,6 +699,16 @@ def context_attention_kv_flattened(
         l_i_new = tl.exp(lse_i - m_ij) + l_ij
         lse_i = m_ij + tl.log(l_i_new)
 
+    # Add sinks contribution to the final softmax calculation
+    if HAS_SINKS:
+        sinks_val = tl.load(sinks_ptr + batch_id * N_HEADS + head_id)
+        m_sinks = tl.maximum(m_i, sinks_val)
+        acc_scale = tl.exp(m_i - m_sinks)
+        acc = acc * acc_scale[:, None]
+        l_sinks = tl.exp(lse_i - m_sinks) + tl.exp(sinks_val - m_sinks)
+        lse_i = m_sinks + tl.log(l_sinks)
+        m_i = m_sinks
+
     o_scale = tl.exp(m_i - lse_i)
 
     acc = acc * o_scale[:, None]
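
The epilogue above folds the sink into the running online-softmax statistics, so the final normalizer becomes sum(exp(qk)) + exp(sink) while the accumulated values are untouched. A hedged NumPy check of that identity (illustrative values, with the per-row statistics collapsed to a single row):

import numpy as np

qk = np.array([0.5, 1.0, -0.3])   # logits already accumulated by the loop
v = np.random.randn(3, 8)         # matching value rows
sink = 0.7                        # per-head sink logit
m_i = qk.max()
acc = np.exp(qk - m_i) @ v
lse_i = m_i + np.log(np.exp(qk - m_i).sum())

m_sinks = max(m_i, sink)
acc = acc * np.exp(m_i - m_sinks)
lse_i = m_sinks + np.log(np.exp(lse_i - m_sinks) + np.exp(sink - m_sinks))
out = acc * np.exp(m_sinks - lse_i)

expected = (np.exp(qk) @ v) / (np.exp(qk).sum() + np.exp(sink))
assert np.allclose(out, expected)
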