
Commit 60e5d01

add-4090-fa (#1153)

1 parent: e2eb4c4

File tree: 2 files changed (+7 lines, -2 lines)


lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 import time
 import torch.nn.functional as F
 from typing import Optional, Tuple
-from lightllm.utils.device_utils import is_hopper
+from lightllm.utils.device_utils import is_hopper, is_4090


 if triton.__version__ >= "2.1.0":

@@ -217,7 +217,7 @@ def flash_attention_fwd(q, k, v, o, cu_seqlens, max_seqlen):
     then use the sgl_kernel interface; otherwise use the Triton version.
     """
     global _flash_attn_v3_available
-    if _flash_attn_v3_available and is_hopper():
+    if _flash_attn_v3_available and (is_hopper() or is_4090()):
         try:
             flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlen)
         except Exception as e:
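To make the changed condition explicit, here is a minimal sketch of which backend it selects after this commit: the sgl_kernel flash-attention v3 path becomes eligible on Hopper GPUs or an RTX 4090, with the Triton kernel used otherwise (and as the fallback when the v3 call raises). The helper name pick_attention_backend is illustrative only and not part of the codebase; it assumes lightllm is importable on a CUDA machine.

# Sketch only: mirrors the condition changed in this commit, not the file's exact body.
from lightllm.utils.device_utils import is_hopper, is_4090

def pick_attention_backend(flash_attn_v3_available: bool) -> str:
    # After this commit, the sgl_kernel FA3 path is eligible on Hopper or RTX 4090.
    if flash_attn_v3_available and (is_hopper() or is_4090()):
        return "flash_attention_v3 (sgl_kernel)"
    return "triton flash_attention"

print(pick_attention_backend(flash_attn_v3_available=True))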

lightllm/utils/device_utils.py

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,11 @@ def is_hopper():
     )


+@lru_cache(maxsize=None)
+def is_4090():
+    return "4090" in torch.cuda.get_device_name(0) or "RTX 4090" in torch.cuda.get_device_name(0)
+
+
 @lru_cache(maxsize=None)
 def get_device_sm_count():
     import triton
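As a usage sketch (assuming a CUDA build of PyTorch and a visible GPU), the new helper can be exercised as below. Because of lru_cache, the device name is queried at most once per process; note also that the "RTX 4090" check is already covered by the "4090" substring check.

import torch
from lightllm.utils.device_utils import is_4090, is_hopper

if torch.cuda.is_available():
    # e.g. "NVIDIA GeForce RTX 4090" -> is_4090() returns True
    print(torch.cuda.get_device_name(0), "->", "4090" if is_4090() else "other")
    print("hopper:", is_hopper())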
