@@ -1,4 +1,4 @@
-# Copyright 2025 The Genmo team and The HuggingFace Team.
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
 # All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,19 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import inspect
 import math
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
@@ -37,20 +37,30 @@
 
 
 class LTXVideoAttentionProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXAttnProcessor`"
+        deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)
+
+        return LTXAttnProcessor(*args, **kwargs)
+
+
+class LTXAttnProcessor:
     r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
-    used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
+    Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
+    model. It applies a normalization layer and rotary embedding on the query and key vector.
     """
 
+    _attention_backend = None
+
     def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+        if is_torch_version("<", "2.0"):
+            raise ValueError(
+                "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
             )
 
     def __call__(
         self,
-        attn: Attention,
+        attn: "LTXAttention",
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
@@ -78,21 +88,90 @@ def __call__(
         query = apply_rotary_emb(query, image_rotary_emb)
         key = apply_rotary_emb(key, image_rotary_emb)
 
-        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        hidden_states = dispatch_attention_fn(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
 
         hidden_states = attn.to_out[0](hidden_states)
         hidden_states = attn.to_out[1](hidden_states)
         return hidden_states
 
 
+class LTXAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = LTXAttnProcessor
+    _available_processors = [LTXAttnProcessor]
+
+    def __init__(
+        self,
+        query_dim: int,
+        heads: int = 8,
+        kv_heads: int = 8,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = True,
+        cross_attention_dim: Optional[int] = None,
+        out_bias: bool = True,
+        qk_norm: str = "rms_norm_across_heads",
+        processor=None,
+    ):
+        super().__init__()
+        if qk_norm != "rms_norm_across_heads":
+            raise NotImplementedError
+
+        self.head_dim = dim_head
+        self.inner_dim = dim_head * heads
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+        self.query_dim = query_dim
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.use_bias = bias
+        self.dropout = dropout
+        self.out_dim = query_dim
+        self.heads = heads
+
+        norm_eps, norm_elementwise_affine = 1e-5, True
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+        self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        self.to_out = torch.nn.ModuleList([])
+        self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+        self.to_out.append(torch.nn.Dropout(dropout))
+
+        if processor is None:
+            processor = self._default_processor_cls()
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
 class LTXVideoRotaryPosEmbed(nn.Module):
     def __init__(
         self,
@@ -231,7 +310,7 @@ def __init__(
         super().__init__()
 
         self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn1 = Attention(
+        self.attn1 = LTXAttention(
             query_dim=dim,
             heads=num_attention_heads,
             kv_heads=num_attention_heads,
@@ -240,11 +319,10 @@ def __init__(
             cross_attention_dim=None,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )
 
         self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
-        self.attn2 = Attention(
+        self.attn2 = LTXAttention(
             query_dim=dim,
             cross_attention_dim=cross_attention_dim,
             heads=num_attention_heads,
@@ -253,7 +331,6 @@ def __init__(
             bias=attention_bias,
             out_bias=attention_out_bias,
             qk_norm=qk_norm,
-            processor=LTXVideoAttentionProcessor2_0(),
         )
 
         self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -299,7 +376,9 @@ def forward(
 
 
 @maybe_allow_in_graph
-class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
+class LTXVideoTransformer3DModel(
+    ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
+):
     r"""
     A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
 
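For reference, a minimal sketch of exercising the new LTXAttention module directly; this assumes the diff is installed and that the import path matches this file's location in diffusers (diffusers.models.transformers.transformer_ltx). The dimensions are illustrative, not a usage contract.

import torch
from diffusers.models.transformers.transformer_ltx import LTXAttention

# With processor=None, __init__ falls back to LTXAttnProcessor, which routes
# attention through dispatch_attention_fn (native SDPA unless a backend is
# configured via _attention_backend). Note: torch.nn.RMSNorm needs PyTorch >= 2.4.
attn = LTXAttention(query_dim=512, heads=8, kv_heads=8, dim_head=64)

hidden_states = torch.randn(2, 128, 512)  # (batch, sequence, query_dim), illustrative shape
out = attn(hidden_states)  # self-attention: encoder_hidden_states defaults to hidden_states
print(out.shape)  # torch.Size([2, 128, 512]); out_dim equals query_dim

Because the processor no longer transposes to (batch, heads, seq, head_dim), dispatch_attention_fn receives (batch, seq, heads, head_dim) tensors and handles any backend-specific layout itself, which is what lets the same processor serve multiple attention backends.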