File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -375,7 +375,7 @@ def _fwd_kernel_flash_attn_v2(
375
375
bn = tl .load (B_Loc + cur_batch * stride_b_loc_b +
376
376
((start_n + offs_n ) // block_size ) * stride_b_loc_s ,
377
377
mask = (start_n + offs_n ) < cur_batch_ctx_len ,
378
- other = 0 )
378
+ other = 0 ). to ( tl . int64 )
379
379
off_k = (
380
380
bn [None , :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
381
381
(offs_d [:, None ] // x ) * stride_k_cache_d +
@@ -583,7 +583,7 @@ def _fwd_kernel_alibi(
583
583
bn = tl .load (B_Loc + cur_batch * stride_b_loc_b +
584
584
((start_n + offs_n ) // block_size ) * stride_b_loc_s ,
585
585
mask = (start_n + offs_n ) < cur_batch_ctx_len ,
586
- other = 0 )
586
+ other = 0 ). to ( tl . int64 )
587
587
off_k = (
588
588
bn [None , :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
589
589
(offs_d [:, None ] // x ) * stride_k_cache_d +
You can’t perform that action at this time.
0 commit comments