Commit c052791

sayakpaul and a-r-r-o-w authored
[core] support attention backends for LTX (huggingface#12021)
* support attention backends for LTX

* Apply suggestions from code review

* reviewer feedback

---------

Co-authored-by: Aryan <[email protected]>
1 parent 843e3f9 commit c052791
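
With this change, LTX's attention layers route through diffusers' attention dispatcher instead of calling F.scaled_dot_product_attention directly, so an attention backend can be selected per model. A minimal usage sketch, assuming the set_attention_backend helper exposed by recent diffusers model classes; the checkpoint path and backend name are illustrative:

import torch
from diffusers import LTXVideoTransformer3DModel

# Illustrative checkpoint; any LTX-Video transformer checkpoint should work.
transformer = LTXVideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-Video", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Assumed helper: points every LTXAttention processor at the requested
# backend (e.g. "flash" for flash-attn). When no backend is set, dispatch
# falls back to native PyTorch SDPA.
transformer.set_attention_backend("flash")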

File tree

1 file changed: +103 -23 lines changed

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 103 additions & 23 deletions
@@ -1,4 +1,4 @@
-# Copyright 2025 The Genmo team and The HuggingFace Team.
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,19 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import math
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
@@ -37,20 +37,30 @@
 
 
 class LTXVideoAttentionProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
+        deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)
+
+        return LTXVideoAttnProcessor(*args, **kwargs)
+
+
+class LTXVideoAttnProcessor:
     r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
-    used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
+    Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
+    model. It applies a normalization layer and rotary embedding on the query and key vector.
     """
 
+    _attention_backend = None
+
     def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+        if is_torch_version("<", "2.0"):
+            raise ValueError(
+                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
             )
 
     def __call__(
         self,
-        attn: Attention,
+        attn: "LTXAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
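
The __new__ shim added at the top of this hunk keeps old imports working: constructing the deprecated LTXVideoAttentionProcessor2_0 warns once and hands back the replacement class, and because __new__ returns a foreign instance, Python never runs the old class's __init__. A standalone sketch of the pattern, with hypothetical class names:

import warnings

class NewProcessor:
    def __init__(self, scale: float = 1.0):
        self.scale = scale

class OldProcessor:
    def __new__(cls, *args, **kwargs):
        # Old call sites keep working but transparently get the new class.
        warnings.warn("OldProcessor is deprecated; use NewProcessor.", FutureWarning)
        return NewProcessor(*args, **kwargs)

proc = OldProcessor(scale=2.0)
assert isinstance(proc, NewProcessor)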
@@ -78,21 +88,91 @@ def __call__(
         query = apply_rotary_emb(query, image_rotary_emb)
         key = apply_rotary_emb(key, image_rotary_emb)
 
-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
 
         hidden_states = attn.to_out[0](hidden_states)
         hidden_states = attn.to_out[1](hidden_states)
         return hidden_states
 
 
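Note the dropped .transpose(1, 2) calls: dispatch_attention_fn takes tensors in (batch, seq_len, heads, head_dim) layout and leaves any transposition to the selected backend, whereas calling F.scaled_dot_product_attention directly requires (batch, heads, seq_len, head_dim). A quick shape check of the old layout dance, with illustrative sizes:

import torch
import torch.nn.functional as F

batch, seq_len, heads, head_dim = 2, 16, 8, 64
q = k = v = torch.randn(batch, seq_len, heads, head_dim)

# SDPA wants heads before sequence, hence the transpose in and back out
# that the dispatched call above no longer needs.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
assert out.shape == (batch, seq_len, heads, head_dim)

The hunk continues with the new LTXAttention module, which owns these processors:
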
+class LTXAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = LTXVideoAttnProcessor
+    _available_processors = [LTXVideoAttnProcessor]
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        kv_heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = True,
+        cross_attention_dim: Optional[int] = None,
+        out_bias: bool = True,
+        qk_norm: str = "rms_norm_across_heads",
+        processor=None,
+    ):
+        super().__init__()
+        if qk_norm != "rms_norm_across_heads":
+            raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
+
+        self.head_dim = dim_head
+        self.inner_dim = dim_head * heads
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = query_dim
+        self.heads = heads
+
+        norm_eps = 1e-5
+        norm_elementwise_affine = True
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_out = torch.nn.ModuleList([])
+        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(torch.nn.Dropout(dropout))
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
 class LTXVideoRotaryPosEmbed(nn.Module):
     def __init__(
         self,
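
LTXAttention.forward filters **kwargs against the active processor's signature, warning about and dropping anything unexpected instead of raising a TypeError. A minimal sketch of that filtering pattern outside the class, with hypothetical function names:

import inspect

def call_with_supported_kwargs(fn, *args, **kwargs):
    # Keep only keyword arguments the callee actually declares, as
    # LTXAttention.forward does before invoking its processor.
    supported = set(inspect.signature(fn).parameters)
    dropped = [k for k in kwargs if k not in supported]
    if dropped:
        print(f"kwargs {dropped} are not expected by {fn.__name__} and will be ignored.")
    return fn(*args, **{k: v for k, v in kwargs.items() if k in supported})

def processor(hidden_states, scale=1.0):
    return hidden_states * scale

out = call_with_supported_kwargs(processor, 3.0, scale=2.0, rescale=True)
assert out == 6.0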
@@ -231,7 +311,7 @@ def __init__(
         super().__init__()
 
         self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn1 = Attention(
+        self.attn1 = LTXAttention(
             query_dim=dim,
             heads=num_attention_heads,
             kv_heads=num_attention_heads,
@@ -240,11 +320,10 @@
             cross_attention_dim=None,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )
 
         self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn2 = Attention(
+        self.attn2 = LTXAttention(
             query_dim=dim,
             cross_attention_dim=cross_attention_dim,
             heads=num_attention_heads,
@@ -253,7 +332,6 @@
             bias=attention_bias,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )
 
         self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -299,7 +377,9 @@ def forward(
 
 
 @maybe_allow_in_graph
-class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
+class LTXVideoTransformer3DModel(
+    ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
+):
     r"""
     A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
