Commit d513efa

support attention backends for LTX

1 parent c02c4a6

File tree

1 file changed: +102 −23 lines changed

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 102 additions & 23 deletions
@@ -1,4 +1,4 @@
-# Copyright 2025 The Genmo team and The HuggingFace Team.
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,19 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 import math
 from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
-import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
@@ -37,20 +37,30 @@


 class LTXVideoAttentionProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXAttnProcessor`"
+        deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)
+
+        return LTXAttnProcessor(*args, **kwargs)
+
+
+class LTXAttnProcessor:
     r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
-    used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
+    Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
+    model. It applies a normalization layer and rotary embedding on the query and key vector.
     """

+    _attention_backend = None
+
     def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+        if is_torch_version("<", "2.0"):
+            raise ValueError(
+                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
             )

     def __call__(
         self,
-        attn: Attention,
+        attn: "LTXAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
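
The `__new__` shim above keeps old call sites working: constructing the deprecated class emits a warning through `deprecate(...)` and hands back an instance of the replacement processor. A minimal sketch of what a caller observes, assuming a diffusers build that contains this commit:

from diffusers.models.transformers.transformer_ltx import (
    LTXAttnProcessor,
    LTXVideoAttentionProcessor2_0,
)

proc = LTXVideoAttentionProcessor2_0()  # emits a deprecation warning via deprecate(...)
assert isinstance(proc, LTXAttnProcessor)  # __new__ returned the new class instead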
@@ -78,21 +88,90 @@ def __call__(
         query = apply_rotary_emb(query, image_rotary_emb)
         key = apply_rotary_emb(key, image_rotary_emb)

-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)

         hidden_states = attn.to_out[0](hidden_states)
         hidden_states = attn.to_out[1](hidden_states)
         return hidden_states


+class LTXAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = LTXAttnProcessor
+    _available_processors = [LTXAttnProcessor]
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        kv_heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = True,
+        cross_attention_dim: Optional[int] = None,
+        out_bias: bool = True,
+        qk_norm: str = "rms_norm_across_heads",
+        processor=None,
+    ):
+        super().__init__()
+        if qk_norm != "rms_norm_across_heads":
+            raise NotImplementedError
+
+        self.head_dim = dim_head
+        self.inner_dim = dim_head * heads
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = query_dim
+        self.heads = heads
+
+        norm_eps, norm_elementwise_affine = 1e-5, True
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_out = torch.nn.ModuleList([])
+        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(torch.nn.Dropout(dropout))
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
 class LTXVideoRotaryPosEmbed(nn.Module):
     def __init__(
         self,
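
Two things change in the hot path above: the tensors are no longer transposed to `[batch, heads, seq, head_dim]`, because `dispatch_attention_fn` consumes the `[batch, seq, heads, head_dim]` layout directly, and the kernel is selected via `backend=self._attention_backend`. A pure-PyTorch shape sketch of the equivalence, with illustrative sizes and SDPA standing in for the dispatcher:

import torch
import torch.nn.functional as F

batch, seq, heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seq, heads * head_dim).unflatten(2, (heads, -1))  # [B, S, H, D]
k, v = torch.randn_like(q), torch.randn_like(q)

# Old path: transpose to [B, H, S, D] for SDPA, then transpose back.
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)
out = out.flatten(2, 3)  # [B, S, H*D]; the new code reaches the same result,
print(out.shape)         # torch.Size([2, 128, 512]), with the layout handling
                         # moved inside the selected backend.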
@@ -231,7 +310,7 @@ def __init__(
         super().__init__()

         self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn1 = Attention(
+        self.attn1 = LTXAttention(
             query_dim=dim,
             heads=num_attention_heads,
             kv_heads=num_attention_heads,
@@ -240,11 +319,10 @@ def __init__(
             cross_attention_dim=None,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )

         self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn2 = Attention(
+        self.attn2 = LTXAttention(
             query_dim=dim,
             cross_attention_dim=cross_attention_dim,
             heads=num_attention_heads,
@@ -253,7 +331,6 @@ def __init__(
             bias=attention_bias,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )

         self.ff = FeedForward(dim, activation_fn=activation_fn)
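
The dropped `processor=LTXVideoAttentionProcessor2_0()` arguments are redundant now that `LTXAttention` declares `_default_processor_cls` and installs it in `__init__`. A quick sketch with illustrative dimensions, assuming `set_processor` stores the instance on `.processor` as the existing `Attention` class does:

from diffusers.models.transformers.transformer_ltx import LTXAttention, LTXAttnProcessor

attn = LTXAttention(query_dim=512, heads=8, kv_heads=8, dim_head=64)
assert isinstance(attn.processor, LTXAttnProcessor)  # installed by set_processor in __init__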
@@ -299,7 +376,9 @@ def forward(


 @maybe_allow_in_graph
-class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
+class LTXVideoTransformer3DModel(
+    ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
+):
     r"""
     A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

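Adding `AttentionMixin` to the model's bases exposes the attention-processor plumbing at the model level, and with the processors routed through `dispatch_attention_fn`, a backend can be swapped without touching processors by hand. A hedged usage sketch, assuming the `set_attention_backend` helper from recent diffusers releases; backend availability depends on your install (e.g. flash-attn for "flash"):

import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.transformer.set_attention_backend("flash")  # or e.g. "native", "sage"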