@@ -219,20 +219,28 @@ def forward_kernel_causal_and_sparse(
219219
220220 if EVEN_N & EVEN_M :
221221 if EVEN_HEADDIM :
222- k = tl .load (k_ptrs )
222+ k = tl .load (
223+ k_ptrs ,
224+ mask = (offs_n [:, None ] >= 0 ),
225+ other = 0.
226+ )
223227 else :
224- k = tl .load (k_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
228+ k = tl .load (
229+ k_ptrs ,
230+ mask = (offs_n [:, None ] >= 0 ) & (offs_d [None , :] < headdim ),
231+ other = 0.0
232+ )
225233 else :
226234 if EVEN_HEADDIM :
227235 k = tl .load (
228236 k_ptrs ,
229- mask = offs_n [:, None ] < seqlen_k ,
237+ mask = ( offs_n [:, None ] >= 0 ) & ( offs_n [:, None ] < seqlen_k ) ,
230238 other = 0.0 ,
231239 )
232240 else :
233241 k = tl .load (
234242 k_ptrs ,
235- mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
243+ mask = (offs_n [:, None ] >= 0 ) & ( offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
236244 other = 0.0 ,
237245 )
238246
@@ -1229,19 +1237,36 @@ def backward_kernel_one_col_block_causal(
12291237 # if we just call tl.load(k_ptrs), we get the wrong output!
12301238 if EVEN_N & EVEN_M :
12311239 if EVEN_HEADDIM :
1232- k = tl .load (k_ptrs )
1240+ k = tl .load (
1241+ k_ptrs ,
1242+ mask = (offs_n [:, None ] >= 0 ),
1243+ other = 0.
1244+ )
12331245 v = tl .load (v_ptrs )
12341246 else :
1235- k = tl .load (k_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
1247+ k = tl .load (
1248+ k_ptrs ,
1249+ mask = (offs_n [:, None ] >= 0 ) & (offs_d [None , :] < headdim ),
1250+ other = 0.0
1251+ )
1252+
12361253 v = tl .load (v_ptrs , mask = offs_d [None , :] < headdim , other = 0.0 )
12371254 else :
12381255 if EVEN_HEADDIM :
1239- k = tl .load (k_ptrs , mask = offs_n [:, None ] < seqlen_k , other = 0.0 )
1256+ k = tl .load (
1257+ k_ptrs ,
1258+ mask = (offs_n [:, None ] >= 0 ) & (offs_n [:, None ] < seqlen_k ),
1259+ other = 0.0
1260+ )
1261+
12401262 v = tl .load (v_ptrs , mask = offs_n [:, None ] < seqlen_k , other = 0.0 )
12411263 else :
12421264 k = tl .load (
1243- k_ptrs , mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ), other = 0.0
1265+ k_ptrs ,
1266+ mask = (offs_n [:, None ] >= 0 ) & (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ),
1267+ other = 0.0
12441268 )
1269+
12451270 v = tl .load (
12461271 v_ptrs , mask = (offs_n [:, None ] < seqlen_k ) & (offs_d [None , :] < headdim ), other = 0.0
12471272 )
@@ -1273,7 +1298,7 @@ def backward_kernel_one_col_block_causal(
12731298
12741299 if BLOCK != SEL_BLOCK :
12751300 block_diagonal_mask = (
1276- (offs_n [None , :] >= 0. ) &
1301+ (offs_n [None , :] >= 0 ) &
12771302 ((offs_n [None , :] // SEL_BLOCK ) == (offs_m [:, None ] // SEL_BLOCK ))
12781303 )
12791304
0 commit comments