Skip to content

Commit bacb278

Browse files
authored
vit fa3 api fix (#1047)
1 parent ae0ac10 commit bacb278

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def flash_attention_v3_fwd(
201 201
num_splits=1,
202 202
pack_gqa=None,
203 203
sm_margin=0,
204 +
sinks=None,
204 205
)
205 206

206 207
return
@@ -215,7 +216,12 @@ def flash_attention_fwd(q, k, v, o, cu_seqlens, max_seqlen):
215 216
统一的 Flash Attention 接口。如果 sgl_kernel 存在,
216 217
则使用 sgl_kernel里的接口,否则使用 Triton 版本。
217 218
"""
219 +
global _flash_attn_v3_available
218 220
if _flash_attn_v3_available and is_hopper():
219 -
flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlen)
221 +
try:
222 +
flash_attention_v3_fwd(q, k, v, o, cu_seqlens, max_seqlen)
223 +
except Exception as e:
224 +
print(f"Failed to use sgl_kernel: {e}")
225 +
_flash_attn_v3_available = False
220 226
else:
221 227
_flash_attention_triton_fwd(q, k, v, o, cu_seqlens, max_seqlen)

0 commit comments

Comments
 (0)