Skip to content

Commit ade3885

Browse files
authored
bugfix: Fix crash when pos_encoding_mode is passed as int (#1413)
This PR is to fix this issue caused by 85d75ca ``` File "/scratch/repo/flashinfer/flashinfer/prefill.py", line 83, in get_fmha_module pos_encoding_mode.value, ^^^^^^^^^^^^^^^^^^^^^^^ AttributeError: 'int' object has no attribute 'value' ``` Update get_fmha_module to expect pos_encoding_mode as an Enum instead of an int; raises error otherwise. cc @yzh119
1 parent 467bb92 commit ade3885

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

flashinfer/prefill.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def get_fmha_module(
6666
dtype_idx: torch.dtype,
6767
head_dim_qk: int,
6868
head_dim_vo: int,
69-
pos_encoding_mode: PosEncodingMode,
69+
pos_encoding_mode: int,
7070
use_sliding_window: bool,
7171
use_logits_soft_cap: bool,
7272
use_fp16_qk_reduction: bool = False,
@@ -79,7 +79,7 @@ def get_fmha_module(
7979
dtype_idx,
8080
head_dim_qk,
8181
head_dim_vo,
82-
pos_encoding_mode.value,
82+
pos_encoding_mode,
8383
use_sliding_window,
8484
use_logits_soft_cap,
8585
).build_and_load()

0 commit comments

Comments
 (0)