@@ -47,7 +47,8 @@ def _fwd_kernel(
    stride_v_cache_bl,
    num_queries_per_kv: int,
    BLOCK_M: tl.constexpr,
-    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,  # head size
+    BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
    BLOCK_N: tl.constexpr,
):
    cur_batch = tl.program_id(0)
@@ -59,26 +60,30 @@ def _fwd_kernel(
    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
-    offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
        cur_head * stride_qh + offs_d[None, :] * stride_qd)

-    q = tl.load(
-        Q + off_q,
-        mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
-        other=0.0)
+    dim_mask = tl.where(
+        tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1)
+
+    q = tl.load(Q + off_q,
+                mask=dim_mask[None, :] &
+                (offs_m[:, None] < cur_batch_query_len),
+                other=0.0)

    # # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)

    for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
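The hunk above is where the padding trick lives: `offs_d` now spans `BLOCK_DMODEL_PADDED` (a power of 2, as Triton block shapes require), and `dim_mask` marks which of those lanes fall inside the real head size, so loads fill the padded lanes with zeros instead of reading out of bounds. A minimal NumPy sketch of the same idea, using an assumed head size of 96 (a value chosen for illustration, not taken from this diff):

```python
# Illustrative sketch only, not kernel code: emulate a masked load over a
# head dimension padded from 96 up to the next power of 2 (128).
import numpy as np

HEAD_SIZE = 96
HEAD_SIZE_PADDED = 2 ** ((HEAD_SIZE - 1).bit_length())  # 128

q_row = np.random.rand(HEAD_SIZE).astype(np.float32)   # one query vector
dim_mask = np.arange(HEAD_SIZE_PADDED) < HEAD_SIZE      # True for real lanes

# Emulate tl.load(..., mask=dim_mask, other=0.0): real lanes keep their
# values, padded lanes become 0.0 rather than reading past the row's end.
q_padded = np.where(dim_mask,
                    np.pad(q_row, (0, HEAD_SIZE_PADDED - HEAD_SIZE)),
                    0.0)

assert np.array_equal(q_padded[:HEAD_SIZE], q_row)
assert not q_padded[HEAD_SIZE:].any()
```

The same `dim_mask` is broadcast along the other axis in the k/v loads and in the final `tl.store` in the hunks below, so the padded columns never contribute to the attention result.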
@@ -99,7 +104,8 @@ def _fwd_kernel(
            offs_d[None, :] * stride_v_cache_d +
            (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
        k = tl.load(K_cache + off_k,
-                    mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
+                    mask=dim_mask[:, None] &
+                    ((start_n + offs_n[None, :]) < cur_batch_ctx_len),
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -126,7 +132,8 @@ def _fwd_kernel(
        acc = acc * acc_scale[:, None]
        # update acc
        v = tl.load(V_cache + off_v,
-                    mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
+                    mask=dim_mask[None, :] &
+                    ((start_n + offs_n[:, None]) < cur_batch_ctx_len),
                    other=0.0)

        p = p.to(v.dtype)
@@ -142,16 +149,15 @@ def _fwd_kernel(
    k_ptrs = K + off_k
    v_ptrs = V + off_v

-    block_mask = tl.where(
-        block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
+    block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
-                    mask=(start_n + offs_n[None, :]) <
-                    cur_batch_seq_len - cur_batch_ctx_len,
+                    mask=dim_mask[:, None] &
+                    ((start_n + offs_n[None, :]) < cur_batch_query_len),
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
@@ -179,8 +185,8 @@ def _fwd_kernel(
        # update acc
        v = tl.load(v_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_vbs,
-                    mask=(start_n + offs_n[:, None]) <
-                    cur_batch_seq_len - cur_batch_ctx_len,
+                    mask=dim_mask[None, :] &
+                    ((start_n + offs_n[:, None]) < cur_batch_query_len),
                    other=0.0)

        p = p.to(v.dtype)
@@ -195,7 +201,8 @@ def _fwd_kernel(
    out_ptrs = Out + off_o
    tl.store(out_ptrs,
             acc,
-             mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
+             mask=dim_mask[None, :] &
+             (offs_m[:, None] < cur_batch_query_len))
    return


@triton.jit
@@ -636,7 +643,8 @@ def context_attention_fwd(q,
    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
-    assert Lk in {16, 32, 64, 128}
+    # round up Lk to a power of 2 - this is required for Triton block size
+    Lk_padded = 2 ** ((Lk - 1).bit_length())

    sm_scale = 1.0 / (Lq ** 0.5)
    batch, head = b_seq_len.shape[0], q.shape[1]
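For reference, `2 ** ((Lk - 1).bit_length())` is the smallest power of 2 greater than or equal to `Lk`, replacing the old hard-coded `assert Lk in {16, 32, 64, 128}`. A quick check with a few assumed head sizes (80, 96, and 120 are illustrative values, not sizes from this diff):

```python
# Illustrative check of the rounding rule used above.
for Lk in (16, 64, 80, 96, 120, 128):
    Lk_padded = 2 ** ((Lk - 1).bit_length())
    print(Lk, "->", Lk_padded)  # 16->16, 64->64, 80->128, 96->128, 120->128, 128->128
```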
@@ -646,6 +654,7 @@ def context_attention_fwd(q,

    num_warps = 8 if Lk <= 64 else 8
    if alibi_slopes is not None:
+        assert Lk == Lk_padded
        _fwd_kernel_alibi[grid](
            q,
            k,
@@ -738,6 +747,7 @@ def context_attention_fwd(q,
        num_queries_per_kv=num_queries_per_kv,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL_PADDED=Lk_padded,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,