@@ -41,7 +41,7 @@ def __new__(cls, *args, **kwargs):
4141 deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
4242 deprecate ("LTXVideoAttentionProcessor2_0" , "1.0.0" , deprecation_message )
4343
44- return LTXAttnProcessor (* args , ** kwargs )
44+ return LTXVideoAttnProcessor (* args , ** kwargs )
4545
4646
4747class LTXVideoAttnProcessor :
@@ -110,8 +110,8 @@ def __call__(
110110
111111
112112class LTXAttention (torch .nn .Module , AttentionModuleMixin ):
113- _default_processor_cls = LTXAttnProcessor
114- _available_processors = [LTXAttnProcessor ]
113+ _default_processor_cls = LTXVideoAttnProcessor
114+ _available_processors = [LTXVideoAttnProcessor ]
115115
116116 def __init__ (
117117 self ,
@@ -128,7 +128,7 @@ def __init__(
128128 ):
129129 super ().__init__ ()
130130 if qk_norm != "rms_norm_across_heads" :
131- raise NotImplementedError
131+ raise NotImplementedError ( "Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`." )
132132
133133 self .head_dim = dim_head
134134 self .inner_dim = dim_head * heads
@@ -140,7 +140,8 @@ def __init__(
140140 self .out_dim = query_dim
141141 self .heads = heads
142142
143- norm_eps , norm_elementwise_affine = 1e-5 , True
143+ norm_eps = 1e-5
144+ norm_elementwise_affine = True
144145 self .norm_q = torch .nn .RMSNorm (dim_head * heads , eps = norm_eps , elementwise_affine = norm_elementwise_affine )
145146 self .norm_k = torch .nn .RMSNorm (dim_head * kv_heads , eps = norm_eps , elementwise_affine = norm_elementwise_affine )
146147 self .to_q = torch .nn .Linear (query_dim , self .inner_dim , bias = bias )
@@ -166,7 +167,7 @@ def forward(
166167 unused_kwargs = [k for k , _ in kwargs .items () if k not in attn_parameters ]
167168 if len (unused_kwargs ) > 0 :
168169 logger .warning (
169- f"joint_attention_kwargs { unused_kwargs } are not expected by { self .processor .__class__ .__name__ } and will be ignored."
170+ f"attention_kwargs { unused_kwargs } are not expected by { self .processor .__class__ .__name__ } and will be ignored."
170171 )
171172 kwargs = {k : w for k , w in kwargs .items () if k in attn_parameters }
172173 return self .processor (self , hidden_states , encoder_hidden_states , attention_mask , image_rotary_emb , ** kwargs )
0 commit comments