
Commit 2c9cdee

fix
1 parent 69b4a4c commit 2c9cdee

1 file changed: +2 −11 lines changed


lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 2 additions & 11 deletions
@@ -4,16 +4,7 @@
 import triton.language as tl
 import math
 import torch.nn.functional as F
-
-from lightllm.utils.device_utils import get_cuda_device_name, get_device_capability
-
-HOPPER = (
-    "H100" in get_cuda_device_name()
-    or "H200" in get_cuda_device_name()
-    or "H800" in get_cuda_device_name()
-    or "Hopper" in get_cuda_device_name()
-)
-
+from lightllm.utils.device_utils import is_hopper
 
 if triton.__version__ >= "2.1.0":

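The hunk above replaces the inlined device-name check with the shared is_hopper() helper from lightllm.utils.device_utils. The helper itself is not shown in this commit; a minimal sketch of what it presumably wraps, reconstructed from the deleted HOPPER expression (the caching decorator and the exact substring list are assumptions), could look like:

# Hypothetical sketch of lightllm.utils.device_utils.is_hopper, mirroring the
# substring checks removed above; the real helper may differ.
from functools import lru_cache

import torch


@lru_cache(maxsize=None)
def get_cuda_device_name() -> str:
    # Marketing name of the active CUDA device, e.g. "NVIDIA H100 80GB HBM3".
    return torch.cuda.get_device_name(torch.cuda.current_device())


def is_hopper() -> bool:
    # Any H100/H200/H800 part (or a name containing "Hopper") counts as Hopper.
    name = get_cuda_device_name()
    return any(tag in name for tag in ("H100", "H200", "H800", "Hopper"))

Centralizing the check lets other kernels reuse the same detection instead of repeating the device-name matching.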
@@ -212,7 +203,7 @@ def flash_attention_fwd(q, k, v, o):
     Unified Flash Attention interface. If _flash_attn_forward exists,
     flash_attention_v3_fwd is used; otherwise the Triton version is used.
     """
-    if _flash_attn_v3_available and HOPPER:
+    if _flash_attn_v3_available and is_hopper():
         flash_attention_v3_fwd(q, k, v, o)
     else:
         _flash_attention_triton_fwd(q, k, v, o)

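The dispatch above only takes the FlashAttention-3 path when the prebuilt kernels imported successfully and the device is a Hopper GPU; everything else falls back to the Triton kernel. The availability flag is set outside this hunk; a plausible guarded-import sketch (the flash_attn_interface module path is an assumption, not confirmed by this diff) is:

# Assumed setup for _flash_attn_v3_available; the import path below is a guess
# based on the docstring's reference to _flash_attn_forward.
try:
    from flash_attn_interface import _flash_attn_forward

    _flash_attn_v3_available = True
except ImportError:
    _flash_attn_forward = None
    _flash_attn_v3_available = False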