Commit 9d4368d (1 parent: f9d626f)

use current_platform.fp8_dtype() for FA (ROCm#483)

Signed-off-by: Divakar Verma <[email protected]>

File tree: 1 file changed (+3, -2 lines)


vllm/attention/ops/triton_flash_attention.py

Lines changed: 3 additions & 2 deletions
@@ -25,6 +25,7 @@
 import triton
 import triton.language as tl
 
+from vllm.platforms import current_platform
 from vllm.utils import is_navi
 
 torch_dtype: tl.constexpr = torch.float16
@@ -391,7 +392,7 @@ def get_autotune_configs():
 
 autotune_configs, autotune_keys = get_autotune_configs()
 
-float8_info = torch.finfo(torch.float8_e4m3fnuz)
+float8_info = torch.finfo(current_platform.fp8_dtype())
 
 
 @triton.autotune(
@@ -834,7 +835,7 @@ def forward(
         if fp8_scales is not None:
             use_fp8 = True
             (q_scale, k_scale, v_scale, p_scale, o_scale) = fp8_scales
-            float8 = torch.float8_e4m3fnuz
+            float8 = current_platform.fp8_dtype()
 
             def check_and_convert(t, scale):
                 if t.dtype != float8:
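
For context: the change replaces the hardcoded ROCm fp8 dtype torch.float8_e4m3fnuz with the platform-dispatched current_platform.fp8_dtype(), so the same Triton flash-attention kernel picks whichever fp8 variant the active hardware supports. A minimal sketch of the idea follows; the pick_fp8_dtype helper is hypothetical (it is not vllm's implementation, only a stand-in for the platform dispatch), but the dtype names and torch.finfo values are real PyTorch.

# Illustrative sketch, not vllm's actual code. ROCm GPUs use the
# "fnuz" e4m3 variant, while NVIDIA GPUs use the OCP e4m3fn variant,
# and the two formats have different finite ranges.
import torch

def pick_fp8_dtype(is_rocm: bool) -> torch.dtype:
    """Hypothetical helper mimicking current_platform.fp8_dtype()."""
    return torch.float8_e4m3fnuz if is_rocm else torch.float8_e4m3fn

for is_rocm in (True, False):
    dtype = pick_fp8_dtype(is_rocm)
    info = torch.finfo(dtype)
    # e4m3fnuz max is 240.0; e4m3fn max is 448.0 -- hardcoding one
    # dtype would bake the wrong clamp range into the kernel.
    print(dtype, info.max)

Because the two e4m3 variants have different finite ranges, routing torch.finfo through the platform helper (as in the float8_info line above) keeps the kernel's fp8 clamping consistent with the format the hardware actually uses.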
