Skip to content

Commit a0177eb

Browse files
committed
up
1 parent 827fc15 commit a0177eb

File tree

1 file changed

+29
-19
lines changed

1 file changed

+29
-19
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
355355
)
356356

357357
elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
358-
if not _CAN_USE_FLASH_ATTN_3:
358+
if not _CAN_USE_FLASH_ATTN_3 and (flash_attn_3_func is None and flash_attn_3_varlen_func is None):
359359
raise RuntimeError(
360360
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
361361
)
@@ -645,24 +645,34 @@ def _flash_attention_3(
645645
deterministic: bool = False,
646646
return_attn_probs: bool = False,
647647
) -> torch.Tensor:
648-
out, lse, *_ = flash_attn_3_func(
649-
q=query,
650-
k=key,
651-
v=value,
652-
softmax_scale=scale,
653-
causal=is_causal,
654-
qv=None,
655-
q_descale=None,
656-
k_descale=None,
657-
v_descale=None,
658-
window_size=window_size,
659-
attention_chunk=0,
660-
softcap=softcap,
661-
num_splits=1,
662-
pack_gqa=None,
663-
deterministic=deterministic,
664-
sm_margin=0,
665-
)
648+
sig = inspect.signature(flash_attn_3_func)
649+
accepted = set(sig.parameters)
650+
params = {
651+
"q": query,
652+
"k": key,
653+
"v": value,
654+
"softmax_scale": scale,
655+
"causal": is_causal,
656+
"qv": None,
657+
"q_descale": None,
658+
"k_descale": None,
659+
"v_descale": None,
660+
"window_size": window_size,
661+
"attention_chunk": 0,
662+
"softcap": softcap,
663+
"num_splits": 1,
664+
"pack_gqa": None,
665+
"deterministic": deterministic,
666+
"sm_margin": 0,
667+
}
668+
kwargs = {}
669+
for name, value in params.items():
670+
if name not in accepted:
671+
logger.debug(f"{name} is not accepted by the `flash_attn_3_func` method, so it will be discarded.")
672+
else:
673+
kwargs[name] = value
674+
675+
out, lse, *_ = flash_attn_3_func(**kwargs)
666676
return (out, lse) if return_attn_probs else out
667677

668678

0 commit comments

Comments
 (0)