Commit 8c8b44a

reviewer feedback.

1 parent 532711f

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 7 additions & 6 deletions
@@ -41,7 +41,7 @@ def __new__(cls, *args, **kwargs):
         deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
         deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)

-        return LTXAttnProcessor(*args, **kwargs)
+        return LTXVideoAttnProcessor(*args, **kwargs)


 class LTXVideoAttnProcessor:
@@ -110,8 +110,8 @@ def __call__(


 class LTXAttention(torch.nn.Module, AttentionModuleMixin):
-    _default_processor_cls = LTXAttnProcessor
-    _available_processors = [LTXAttnProcessor]
+    _default_processor_cls = LTXVideoAttnProcessor
+    _available_processors = [LTXVideoAttnProcessor]

     def __init__(
         self,
@@ -128,7 +128,7 @@ def __init__(
     ):
         super().__init__()
         if qk_norm != "rms_norm_across_heads":
-            raise NotImplementedError
+            raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")

         self.head_dim = dim_head
         self.inner_dim = dim_head * heads
@@ -140,7 +140,8 @@ def __init__(
         self.out_dim = query_dim
         self.heads = heads

-        norm_eps, norm_elementwise_affine = 1e-5, True
+        norm_eps = 1e-5
+        norm_elementwise_affine = True
         self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
         self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
         self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
@@ -166,7 +167,7 @@ def forward(
         unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
         if len(unused_kwargs) > 0:
             logger.warning(
-                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+                f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
             )
         kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
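
For reference, a minimal sketch of the behavior exercised by the first hunk: constructing the deprecated LTXVideoAttentionProcessor2_0 now redirects to LTXVideoAttnProcessor rather than LTXAttnProcessor. The import path below is an assumption based on the file touched in this commit (src/diffusers/models/transformers/transformer_ltx.py) and may differ between diffusers versions.

# Sketch only: the import path is assumed from the file changed in this commit
# and may vary by diffusers version.
from diffusers.models.transformers.transformer_ltx import (
    LTXVideoAttentionProcessor2_0,
    LTXVideoAttnProcessor,
)

# Per the patched __new__ above, constructing the deprecated class emits a
# deprecation warning and hands back an LTXVideoAttnProcessor instance.
processor = LTXVideoAttentionProcessor2_0()
assert isinstance(processor, LTXVideoAttnProcessor)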
