f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
361
361
)
@@ -645,24 +645,34 @@ def _flash_attention_3(
     deterministic: bool = False,
     return_attn_probs: bool = False,
 ) -> torch.Tensor:
-    out, lse, *_ = flash_attn_3_func(
-        q=query,
-        k=key,
-        v=value,
-        softmax_scale=scale,
-        causal=is_causal,
-        qv=None,
-        q_descale=None,
-        k_descale=None,
-        v_descale=None,
-        window_size=window_size,
-        attention_chunk=0,
-        softcap=softcap,
-        num_splits=1,
-        pack_gqa=None,
-        deterministic=deterministic,
-        sm_margin=0,
-    )
+    sig = inspect.signature(flash_attn_3_func)
+    accepted = set(sig.parameters)
+    params = {
+        "q": query,
+        "k": key,
+        "v": value,
+        "softmax_scale": scale,
+        "causal": is_causal,
+        "qv": None,
+        "q_descale": None,
+        "k_descale": None,
+        "v_descale": None,
+        "window_size": window_size,
+        "attention_chunk": 0,
+        "softcap": softcap,
+        "num_splits": 1,
+        "pack_gqa": None,
+        "deterministic": deterministic,
+        "sm_margin": 0,
+    }
+    kwargs = {}
+    for name, value in params.items():
+        if name not in accepted:
+            logger.debug(f"{name} is not accepted by the `flash_attn_3_func` method, so it will be discarded.")