
Commit 29d8fcc

Commit message: update
1 parent 15ed1d1

File tree: 6 files changed (+165, -18 lines)


src/diffusers/models/attention_dispatch.py

Lines changed: 11 additions & 1 deletion
@@ -198,9 +198,19 @@ def dispatch_attention_fn(
     scale: Optional[float] = None,
     enable_gqa: bool = False,
     attention_kwargs: Optional[Dict[str, Any]] = None,
+    *,
+    backend: Optional[AttentionBackendName] = None,
 ) -> torch.Tensor:
     attention_kwargs = attention_kwargs or {}
-    backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend()
+
+    if backend is None:
+        # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment
+        # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager
+        backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend()
+    else:
+        backend_name = AttentionBackendName(backend)
+        backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
+
     kwargs = {
         "query": query,
         "key": key,

src/diffusers/models/attention_processor.py

Lines changed: 86 additions & 13 deletions
@@ -997,6 +997,8 @@ def forward(
 class MochiAttnProcessor2_0:
     """Attention processor used in Mochi."""
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
@@ -1074,7 +1076,9 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin):
             valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
             valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
 
-            attn_output = dispatch_attention_fn(valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False)
+            attn_output = dispatch_attention_fn(
+                valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False, backend=self._attention_backend
+            )
             valid_sequence_length = attn_output.size(2)
             attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
             attn_outputs.append(attn_output)
@@ -2274,6 +2278,8 @@ def __call__(
 class FluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -2339,7 +2345,13 @@ def __call__(
             key = apply_rotary_emb(key, image_rotary_emb)
 
         hidden_states = dispatch_attention_fn(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -2366,6 +2378,8 @@ def __call__(
 class FluxAttnProcessor2_0_NPU:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
@@ -2448,7 +2462,9 @@ def __call__(
                 inner_precise=0,
             )[0]
         else:
-            hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False)
+            hidden_states = dispatch_attention_fn(
+                query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend
+            )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -2472,6 +2488,8 @@ def __call__(
 class FusedFluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
@@ -2542,7 +2560,9 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = dispatch_attention_fn(
+            query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend
+        )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -2567,6 +2587,8 @@ def __call__(
 class FusedFluxAttnProcessor2_0_NPU:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
@@ -2653,7 +2675,9 @@ def __call__(
                 inner_precise=0,
             )[0]
         else:
-            hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False)
+            hidden_states = dispatch_attention_fn(
+                query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend
+            )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -2678,6 +2702,8 @@ def __call__(
 class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
     """Flux Attention processor for IP-Adapter."""
 
+    _attention_backend = None
+
     def __init__(
         self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
     ):
@@ -2775,7 +2801,9 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = dispatch_attention_fn(
+            query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend
+        )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -2806,7 +2834,13 @@ def __call__(
             # the output of sdp = (batch, num_heads, seq_len, head_dim)
             # TODO: add support for attn.scale when we move to Torch 2.1
             current_ip_hidden_states = dispatch_attention_fn(
-                ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                ip_query,
+                ip_key,
+                ip_value,
+                attn_mask=None,
+                dropout_p=0.0,
+                is_causal=False,
+                backend=self._attention_backend,
             )
             current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
                 batch_size, -1, attn.heads * head_dim
@@ -2825,6 +2859,8 @@ class CogVideoXAttnProcessor2_0:
     query and key vectors, but does not include spatial normalization.
     """
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -2872,7 +2908,13 @@ def __call__(
             key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
 
         hidden_states = dispatch_attention_fn(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -2894,6 +2936,8 @@ class FusedCogVideoXAttnProcessor2_0:
     query and key vectors, but does not include spatial normalization.
     """
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -2943,7 +2987,13 @@ def __call__(
             key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
 
         hidden_states = dispatch_attention_fn(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -3129,9 +3179,10 @@ class AttnProcessorNPU:
     Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If
     fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is
     not significant.
-
     """
 
+    _attention_backend = None
+
     def __init__(self):
         if not is_torch_npu_available():
             raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.")
@@ -3216,7 +3267,13 @@ def __call__(
         else:
             # TODO: add support for attn.scale when we move to Torch 2.1
             hidden_states = dispatch_attention_fn(
-                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+                query,
+                key,
+                value,
+                attn_mask=attention_mask,
+                dropout_p=0.0,
+                is_causal=False,
+                backend=self._attention_backend,
             )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -3243,6 +3300,8 @@ class AttnProcessor2_0:
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
     """
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -3310,7 +3369,13 @@ def __call__(
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = dispatch_attention_fn(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
@@ -3553,6 +3618,8 @@ class MochiVaeAttnProcessor2_0:
    Attention processor used in Mochi VAE.
    """
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -3614,7 +3681,13 @@ def __call__(
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
         hidden_states = dispatch_attention_fn(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=attn.is_causal,
+            backend=self._attention_backend,
         )
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
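
Every processor change in this file follows the same two-part pattern: a class-level `_attention_backend = None` default, plus threading that value into `dispatch_attention_fn`. Stripped down to its essentials (the processor class below is hypothetical, for illustration only):

    from diffusers.models.attention_dispatch import dispatch_attention_fn

    class ExampleAttnProcessor:
        # Class-level default; None defers to the registry's active backend.
        _attention_backend = None

        def __call__(self, query, key, value):
            # set_attention_backend() (added in modeling_utils.py below) writes an
            # instance attribute that shadows this class-level default per processor.
            return dispatch_attention_fn(
                query, key, value, dropout_p=0.0, is_causal=False,
                backend=self._attention_backend,
            )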

src/diffusers/models/modeling_utils.py

Lines changed: 46 additions & 0 deletions
@@ -599,6 +599,52 @@ def enable_group_offload(
             low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
+    def set_attention_backend(self, backend: str) -> None:
+        """
+        Set the attention backend for the model.
+
+        Args:
+            backend (`str`):
+                The name of the backend to set. Must be one of the available backends defined in
+                `AttentionBackendName`. Available backends can be found in
+                `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
+                attention as backend.
+        """
+        from .attention_dispatch import AttentionBackendName
+        from .attention_processor import Attention, MochiAttention
+
+        backend = backend.lower()
+        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+        if backend not in available_backends:
+            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
+        backend = AttentionBackendName(backend)
+        attention_classes = (Attention, MochiAttention)
+
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if processor is None or not hasattr(processor, "_attention_backend"):
+                continue
+            processor._attention_backend = backend
+
+    def reset_attention_backend(self) -> None:
+        """
+        Resets the attention backend for the model. Following calls to `forward` will use the environment default or
+        the torch native scaled dot product attention.
+        """
+        from .attention_processor import Attention, MochiAttention
+
+        attention_classes = (Attention, MochiAttention)
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if processor is None or not hasattr(processor, "_attention_backend"):
+                continue
+            processor._attention_backend = None
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
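
A usage sketch for the two new methods; the checkpoint and the "flash" backend name are assumptions for illustration:

    import torch
    from diffusers import FluxTransformer2DModel

    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
    )

    # Walks self.modules(), skips modules that are not Attention/MochiAttention or
    # whose processor lacks `_attention_backend`, and pins the rest to one backend.
    # An unknown name raises ValueError listing the available backends.
    transformer.set_attention_backend("flash")

    # ... run inference ...

    # Sets `_attention_backend` back to None, so subsequent forwards fall back to
    # the environment default or torch-native scaled dot product attention.
    transformer.reset_attention_backend()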

src/diffusers/models/transformers/transformer_lumina2.py

Lines changed: 5 additions & 1 deletion
@@ -72,6 +72,8 @@ class Lumina2AttnProcessor2_0:
     used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
     """
 
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -138,7 +140,9 @@ def __call__(
         key = key.transpose(1, 2)
         value = value.transpose(1, 2)
 
-        hidden_states = dispatch_attention_fn(query, key, value, attn_mask=attention_mask, scale=softmax_scale)
+        hidden_states = dispatch_attention_fn(
+            query, key, value, attn_mask=attention_mask, scale=softmax_scale, backend=self._attention_backend
+        )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.type_as(query)
 
src/diffusers/models/transformers/transformer_wan.py

Lines changed: 16 additions & 2 deletions
@@ -36,6 +36,8 @@
 
 
 class WanAttnProcessor2_0:
+    _attention_backend = None
+
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
@@ -92,13 +94,25 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
             value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
 
             hidden_states_img = dispatch_attention_fn(
-                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+                query,
+                key_img,
+                value_img,
+                attn_mask=None,
+                dropout_p=0.0,
+                is_causal=False,
+                backend=self._attention_backend,
             )
             hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
 
         hidden_states = dispatch_attention_fn(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
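
Because WanAttnProcessor2_0 now carries `_attention_backend`, the model-level setter from modeling_utils.py reaches it as well. A pipeline-level sketch (the model id, backend name, and generation arguments are assumptions):

    import torch
    from diffusers import WanPipeline

    pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # pipe.transformer is the ModelMixin subclass that gained set_attention_backend().
    pipe.transformer.set_attention_backend("flash")

    video = pipe(prompt="A cat surfing a wave", num_frames=33).frames[0]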
