Commit 45a809a

add support for FA3, NPU, XLA

1 parent 66deeac

File tree (3 files changed: +207, -34 lines)

  src/diffusers/models/attention_dispatch.py
  src/diffusers/utils/__init__.py
  src/diffusers/utils/import_utils.py


src/diffusers/models/attention_dispatch.py

Lines changed: 201 additions & 34 deletions
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,43 +15,49 @@
 import contextlib
 import functools
 import inspect
+import math
 from enum import Enum
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 import torch

 from ..utils import (
-    OptionalDependencyNotAvailable,
     get_logger,
+    is_flash_attn_3_available,
     is_flash_attn_available,
     is_flash_attn_version,
     is_sageattention_available,
     is_sageattention_version,
+    is_torch_npu_available,
     is_torch_version,
+    is_torch_xla_available,
+    is_torch_xla_version,
     is_xformers_available,
     is_xformers_version,
 )
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS


-if is_flash_attn_available():
-    if is_flash_attn_version("<", "2.6.3"):
-        raise OptionalDependencyNotAvailable(
-            "The `flash-attn` library version is too old. Please update it to at least 2.6.3."
-        )
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+

+if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"):
     from flash_attn import flash_attn_func, flash_attn_varlen_func
 else:
+    logger.warning("`flash-attn` is not available or the version is too old. Please install `flash-attn>=2.6.3`.")
     flash_attn_func = None
     flash_attn_varlen_func = None


-if is_sageattention_available():
-    if is_sageattention_version("<", "2.1.1"):
-        raise OptionalDependencyNotAvailable(
-            "The `sageattention` library version is too old. Please update it to at least 2.1.1."
-        )
+if is_flash_attn_3_available():
+    from flash_attn_interface import flash_attn_func as flash_attn_3_func
+    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+else:
+    flash_attn_3_func = None
+    flash_attn_3_varlen_func = None

+
+if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"):
     from sageattention import (
         sageattn,
         sageattn_qk_int8_pv_fp8_cuda,
@@ -61,6 +67,9 @@
         sageattn_varlen,
     )
 else:
+    logger.warning(
+        "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`."
+    )
     sageattn = None
     sageattn_qk_int8_pv_fp16_cuda = None
     sageattn_qk_int8_pv_fp16_triton = None
@@ -76,19 +85,25 @@
     import torch.nn.attention.flex_attention as flex_attention


-if is_xformers_available():
-    if is_xformers_version("<", "0.0.29"):
-        raise OptionalDependencyNotAvailable(
-            "The `xformers` library version is too old. Please update it to at least 0.0.29."
-        )
+if is_torch_npu_available():
+    from torch_npu import npu_fusion_attention
+else:
+    npu_fusion_attention = None

+
+if is_torch_xla_available() and is_torch_xla_version(">", "2.2"):
+    from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+else:
+    xla_flash_attention = None
+
+
+if is_xformers_available() and is_xformers_version(">=", "0.0.29"):
     import xformers.ops as xops
 else:
+    logger.warning("`xformers` is not available or the version is too old. Please install `xformers>=0.0.29`.")
     xops = None


-logger = get_logger(__name__)  # pylint: disable=invalid-name
-
 _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
 _SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
 _SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
@@ -100,6 +115,8 @@ class AttentionBackendName(str, Enum):
     # `flash-attn`
     FLASH = "flash"
     FLASH_VARLEN = "flash_varlen"
+    _FLASH_3 = "_flash_3"
+    _FLASH_VARLEN_3 = "_flash_varlen_3"

     # PyTorch native
     FLEX = "flex"
@@ -108,6 +125,8 @@ class AttentionBackendName(str, Enum):
     _NATIVE_EFFICIENT = "_native_efficient"
     _NATIVE_FLASH = "_native_flash"
     _NATIVE_MATH = "_native_math"
+    _NATIVE_NPU = "_native_npu"
+    _NATIVE_XLA = "_native_xla"

     # `sageattention`
     SAGE = "sage"
@@ -274,7 +293,7 @@ def _check_shape(
 # ===== Helper functions =====


-@functools.lru_cache(maxsize=1)
+@functools.lru_cache(maxsize=8)
 def _prepare_for_flash_attn_or_sage_varlen(
     batch_size: int,
     seq_len_q: int,
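A plausible reading of the maxsize bump (the commit itself does not state the motivation): with a single cache slot, two alternating (batch, seq_len) shapes evict each other on every call, so the cached varlen metadata is never reused. A toy sketch of that eviction behavior with a stand-in function:

# Toy illustration of lru_cache eviction; prep() stands in for
# _prepare_for_flash_attn_or_sage_varlen and only counts real invocations.
import functools

calls = 0

@functools.lru_cache(maxsize=1)
def prep(batch_size: int, seq_len: int):
    global calls
    calls += 1
    return batch_size, seq_len

for _ in range(4):
    prep(1, 128)   # e.g. a text-encoder-sized call
    prep(1, 4096)  # e.g. a transformer-sized call

print(calls)  # 8: every call misses with maxsize=1; with maxsize=8 it would be 2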
@@ -371,12 +390,7 @@ def _flash_attention(
     alibi_slopes: Optional[torch.Tensor] = None,
     deterministic: bool = False,
     return_attn_probs: bool = False,
-    attn_mask: Optional[torch.Tensor] = None,
-    enable_gqa: bool = False,
 ) -> torch.Tensor:
-    if enable_gqa:
-        raise NotImplementedError("GQA is not yet supported.")
-
     query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
     out = flash_attn_func(
         q=query,
@@ -392,7 +406,6 @@ def _flash_attention(
         return_attn_probs=return_attn_probs,
     )
     out = out.permute(0, 2, 1, 3)
-
     return out


@@ -417,17 +430,13 @@ def _flash_varlen_attention(
     deterministic: bool = False,
     return_attn_probs: bool = False,
     attn_mask: Optional[torch.Tensor] = None,
-    enable_gqa: bool = False,
 ) -> torch.Tensor:
     batch_size, _, seq_len_q, _ = query.shape
     _, _, seq_len_kv, _ = key.shape

     if attn_mask is not None:
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

-    if enable_gqa:
-        raise NotImplementedError("GQA is not yet supported.")
-
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
         (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
             _prepare_for_flash_attn_or_sage_varlen(
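When cu_seqlens_q / cu_seqlens_k are not supplied, _prepare_for_flash_attn_or_sage_varlen derives them; its body is outside this diff. For orientation, the conventional flash-attn varlen layout is cumulative sequence lengths with a leading zero, as int32 on the query's device. A minimal sketch under that assumption (this mirrors the expected output format, not the helper's actual implementation):

# Sketch of the usual cu_seqlens convention for flash-attn style varlen kernels.
import torch

seqlens_k = torch.tensor([128, 77, 512], dtype=torch.int32)  # per-sample K/V lengths
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0)
)
max_seqlen_k = int(seqlens_k.max())

print(cu_seqlens_k.tolist())  # [0, 128, 205, 717]
print(max_seqlen_k)           # 512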
@@ -473,6 +482,121 @@ def _flash_varlen_attention(
     return out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_3,
+    constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_3(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    return_attn_probs: bool = False,
+) -> torch.Tensor:
+    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+    out, lse, *_ = flash_attn_3_func(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+        qv=None,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
+        window_size=window_size,
+        attention_chunk=0,
+        softcap=softcap,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=deterministic,
+        sm_margin=0,
+    )
+    out = out.permute(0, 2, 1, 3)
+    return (out, lse) if return_attn_probs else out
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_VARLEN_3,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_varlen_attention_3(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    cu_seqlens_q: Optional[torch.Tensor] = None,
+    cu_seqlens_k: Optional[torch.Tensor] = None,
+    max_seqlen_q: Optional[int] = None,
+    max_seqlen_k: Optional[int] = None,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    return_attn_probs: bool = False,
+    attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    batch_size, _, seq_len_q, _ = query.shape
+    _, _, seq_len_kv, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
+        )
+    else:
+        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+
+    key_valid, value_valid = [], []
+    for b in range(batch_size):
+        valid_len = seqlens_k[b]
+        key_valid.append(key[b, :valid_len])
+        value_valid.append(value[b, :valid_len])
+
+    query_packed = query.flatten(0, 1)
+    key_packed = torch.cat(key_valid, dim=0)
+    value_packed = torch.cat(value_valid, dim=0)
+
+    out, lse, *_ = flash_attn_3_varlen_func(
+        q=query_packed,
+        k=key_packed,
+        v=value_packed,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        seqused_q=None,
+        seqused_k=None,
+        softmax_scale=scale,
+        causal=is_causal,
+        qv=None,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
+        window_size=window_size,
+        softcap=softcap,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=deterministic,
+        sm_margin=0,
+    )
+    out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)
+
+    return (out, lse) if return_attn_probs else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLEX,
     constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
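Both FA3 wrappers take tensors in the (batch, heads, seq_len, head_dim) layout and permute to the (batch, seq_len, heads, head_dim) layout that flash_attn_interface expects. A rough smoke test, assuming the flash_attn_3 package is installed and a supported (Hopper-class) GPU is present; treat it as a sketch rather than a guaranteed run:

# Rough smoke test for the new FA3 wrapper (calls a private diffusers function).
import torch
from diffusers.models.attention_dispatch import _flash_attention_3

q, k, v = (torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))

out = _flash_attention_3(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 256, 64]), same BHSD layout as the inputs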
@@ -668,6 +792,53 @@ def _native_math_attention(
     )


+@_AttentionBackendRegistry.register(
+    AttentionBackendName._NATIVE_NPU,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_npu_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    scale: Optional[float] = None,
+) -> torch.Tensor:
+    return npu_fusion_attention(
+        query,
+        key,
+        value,
+        query.size(1),  # num_heads
+        input_layout="BNSD",
+        pse=None,
+        scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+        pre_tockens=65536,
+        next_tockens=65536,
+        keep_prob=1.0 - dropout_p,
+        sync=False,
+        inner_precise=0,
+    )[0]
+
+
+# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._NATIVE_XLA,
+    constraints=[_check_device, _check_shape],
+)
+def _native_xla_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+) -> torch.Tensor:
+    query = query / math.sqrt(query.shape[-1])
+    return xla_flash_attention(
+        q=query,
+        k=key,
+        v=value,
+        causal=is_causal,
+    )
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.SAGE,
     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
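The NPU path hands the BNSD tensors to npu_fusion_attention directly, while the XLA path pre-scales the query because it passes no softmax scale to torch_xla's flash_attention. Assuming the XLA kernel's default scale is 1.0 (not shown in this diff), that pre-scaling is equivalent to the standard scaled-dot-product formulation; a quick CPU check of the identity with plain PyTorch, no torch_xla or torch_npu required:

# Verifies that dividing q by sqrt(head_dim) before an unscaled attention equals
# the usual softmax(q k^T / sqrt(d)) v.
import math
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 2, 16, 8) for _ in range(3))

ref = F.scaled_dot_product_attention(q, k, v)  # default scale = 1/sqrt(8)
unscaled = torch.softmax((q / math.sqrt(q.shape[-1])) @ k.transpose(-2, -1), dim=-1) @ v

print(torch.allclose(ref, unscaled, atol=1e-5))  # True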
@@ -707,17 +878,13 @@ def _sage_varlen_attention(
     scale: Optional[float] = None,
     smooth_k: bool = True,
     attn_mask: Optional[torch.Tensor] = None,
-    enable_gqa: bool = False,
 ) -> torch.Tensor:
     batch_size, _, seq_len_q, _ = query.shape
     _, _, seq_len_kv, _ = key.shape

     if attn_mask is not None:
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

-    if enable_gqa:
-        raise NotImplementedError("GQA is not yet supported.")
-
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
         (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
             _prepare_for_flash_attn_or_sage_varlen(

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@
     is_bitsandbytes_version,
     is_bs4_available,
     is_cosmos_guardrail_available,
+    is_flash_attn_3_available,
     is_flash_attn_available,
     is_flash_attn_version,
     is_flax_available,

src/diffusers/utils/import_utils.py

Lines changed: 5 additions & 0 deletions
@@ -221,6 +221,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
 _sageattention_available, _sageattention_version = _is_package_available("sageattention")
 _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
+_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")


 def is_torch_available():
@@ -387,6 +388,10 @@ def is_flash_attn_available():
     return _flash_attn_available


+def is_flash_attn_3_available():
+    return _flash_attn_3_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
