
Commit d457cf2

rasmith authored and micah-wil committed
add additional casts into _fwd_kernel_flash_attn_v2 and _fwd_kernel_alibi
Signed-off-by: Randall Smith <[email protected]>
1 parent 5929779 · commit d457cf2

File tree

1 file changed: +2 −2 lines


vllm/attention/ops/prefix_prefill.py

Lines changed: 2 additions & 2 deletions
@@ -375,7 +375,7 @@ def _fwd_kernel_flash_attn_v2(
         bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                      ((start_n + offs_n) // block_size) * stride_b_loc_s,
                      mask=(start_n + offs_n) < cur_batch_ctx_len,
-                     other=0)
+                     other=0).to(tl.int64)
         off_k = (
             bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
             (offs_d[:, None] // x) * stride_k_cache_d +
@@ -583,7 +583,7 @@ def _fwd_kernel_alibi(
         bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                      ((start_n + offs_n) // block_size) * stride_b_loc_s,
                      mask=(start_n + offs_n) < cur_batch_ctx_len,
-                     other=0)
+                     other=0).to(tl.int64)
         off_k = (
             bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h +
             (offs_d[:, None] // x) * stride_k_cache_d +
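
For context (an inference from the diff, not stated in the commit message): bn holds block-table entries loaded from B_Loc, and the off_k computation that follows multiplies them by cache strides to form pointer offsets. If the loaded values stay 32-bit, products such as bn * stride_k_cache_bs can overflow int32 once the KV cache is large enough; casting to tl.int64 before the multiply keeps the address arithmetic in 64 bits. Below is a minimal Triton sketch of the same pattern; the kernel name _gather_rows and its arguments are illustrative, not part of vllm:

import torch
import triton
import triton.language as tl


@triton.jit
def _gather_rows(table_ptr, src_ptr, out_ptr, stride_src,
                 BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Entries load as int32 when the table tensor is int32; cast to
    # int64 *before* multiplying by a potentially large stride so the
    # computed offset cannot wrap around at 2**31 - 1.
    idx = tl.load(table_ptr + offs).to(tl.int64)
    vals = tl.load(src_ptr + idx * stride_src)
    tl.store(out_ptr + offs, vals)


# Usage: gather values from a flat buffer via an int32 index table.
table = torch.arange(8, dtype=torch.int32, device="cuda")
src = torch.randn(8, device="cuda")
out = torch.empty_like(src)
_gather_rows[(1,)](table, src, out, 1, BLOCK=8)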
