
Commit 301711b

Fix processor problem
1 parent 90ce00f

File tree

1 file changed: +17 -29 lines changed


src/diffusers/models/transformers/transformer_easyanimate.py

File mode changed: 100644 -> 100755
Lines changed: 17 additions & 29 deletions
@@ -73,8 +73,7 @@ class EasyAnimateAttnProcessor2_0:
     used in the EasyAnimateTransformer3DModel model.
     """
 
-    def __init__(self, attn2=None):
-        self.attn2 = attn2
+    def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError(
                 "EasyAnimateAttnProcessor2_0 requires PyTorch 2.0 or above. To use it, please install PyTorch 2.0."
@@ -84,11 +83,11 @@ def __call__(
         self,
         attn: Attention,
         hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if self.attn2 is None and encoder_hidden_states is not None:
+        if attn.add_q_proj is None and encoder_hidden_states is not None:
             hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
         # 1. QKV projections
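When a layer was built without the added projections (attn.add_q_proj is None), the rewritten processor falls back to plain joint self-attention: the text tokens are prepended to the video tokens and share one QKV projection downstream. A self-contained sketch of that concatenation, with toy sizes rather than the EasyAnimate defaults:

import torch

batch, text_len, video_len, dim = 2, 16, 128, 64
encoder_hidden_states = torch.randn(batch, text_len, dim)  # text tokens
hidden_states = torch.randn(batch, video_len, dim)         # video tokens

# Both streams become one sequence for the shared QKV projection.
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
assert hidden_states.shape == (batch, text_len + video_len, dim)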
@@ -107,19 +106,19 @@ def __call__(
             key = attn.norm_k(key)
 
         # 3. Encoder condition QKV projection and normalization
-        if self.attn2.to_q is not None and encoder_hidden_states is not None:
-            encoder_query = self.attn2.to_q(encoder_hidden_states)
-            encoder_key = self.attn2.to_k(encoder_hidden_states)
-            encoder_value = self.attn2.to_v(encoder_hidden_states)
+        if attn.add_q_proj is not None and encoder_hidden_states is not None:
+            encoder_query = attn.add_q_proj(encoder_hidden_states)
+            encoder_key = attn.add_k_proj(encoder_hidden_states)
+            encoder_value = attn.add_v_proj(encoder_hidden_states)
 
             encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
             encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
             encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
 
-            if self.attn2.norm_q is not None:
-                encoder_query = self.attn2.norm_q(encoder_query)
-            if self.attn2.norm_k is not None:
-                encoder_key = self.attn2.norm_k(encoder_key)
+            if attn.norm_added_q is not None:
+                encoder_query = attn.norm_added_q(encoder_query)
+            if attn.norm_added_k is not None:
+                encoder_key = attn.norm_added_k(encoder_key)
 
             query = torch.cat([encoder_query, query], dim=2)
             key = torch.cat([encoder_key, key], dim=2)
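The unflatten/transpose pair reshapes each projection from (batch, seq, heads * head_dim) to (batch, heads, seq, head_dim); after that reshape, dim=2 is the token axis, which is why the encoder and video streams are concatenated along dim=2 above. A runnable sketch with toy dimensions:

import torch

batch, seq, heads, head_dim = 2, 16, 8, 64
x = torch.randn(batch, seq, heads * head_dim)

x = x.unflatten(2, (heads, -1))  # (batch, seq, heads, head_dim)
x = x.transpose(1, 2)            # (batch, heads, seq, head_dim)
assert x.shape == (batch, heads, seq, head_dim)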
@@ -154,9 +153,8 @@ def __call__(
                 hidden_states = attn.to_out[0](hidden_states)
                 hidden_states = attn.to_out[1](hidden_states)
 
-            if self.attn2 is not None and getattr(self.attn2, "to_out", None) is not None:
-                encoder_hidden_states = self.attn2.to_out[0](encoder_hidden_states)
-                encoder_hidden_states = self.attn2.to_out[1](encoder_hidden_states)
+            if getattr(attn, "to_add_out", None) is not None:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
         else:
             if getattr(attn, "to_out", None) is not None:
                 hidden_states = attn.to_out[0](hidden_states)
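The context stream's output projection also moved: the old code pushed encoder_hidden_states through attn2.to_out (a Linear followed by a Dropout, hence the [0] and [1] indexing), while the new code uses the single Linear that Attention exposes as to_add_out. A sketch of the two output paths, with toy dimensions and modules standing in for the real attributes:

import torch
import torch.nn as nn

dim = 64
to_out = nn.ModuleList([nn.Linear(dim, dim), nn.Dropout(0.0)])  # video path
to_add_out = nn.Linear(dim, dim)                                # text path

hidden_states = torch.randn(2, 128, dim)
encoder_hidden_states = torch.randn(2, 16, dim)

hidden_states = to_out[1](to_out[0](hidden_states))        # project, then dropout
encoder_hidden_states = to_add_out(encoder_hidden_states)  # single projection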
@@ -192,27 +190,17 @@ def __init__(
             time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
         )
 
-        if is_mmdit_block:
-            self.attn2 = Attention(
-                query_dim=dim,
-                dim_head=attention_head_dim,
-                heads=num_attention_heads,
-                qk_norm="layer_norm" if qk_norm else None,
-                eps=1e-6,
-                bias=True,
-                processor=EasyAnimateAttnProcessor2_0(),
-            )
-        else:
-            self.attn2 = None
-
         self.attn1 = Attention(
             query_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             qk_norm="layer_norm" if qk_norm else None,
             eps=1e-6,
             bias=True,
-            processor=EasyAnimateAttnProcessor2_0(self.attn2),
+            added_proj_bias=True,
+            added_kv_proj_dim=dim if is_mmdit_block else None,
+            context_pre_only=False if is_mmdit_block else None,
+            processor=EasyAnimateAttnProcessor2_0(),
         )
 
         # FFN Part
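The net effect of the commit: instead of a second Attention module stored on the block and handed to the processor, the text-stream projections now live on attn1 itself via diffusers' added-KV-projection support. A hedged construction sketch, assuming the import paths below match the installed diffusers version and using toy hyperparameters rather than the real EasyAnimate config:

import torch
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_easyanimate import EasyAnimateAttnProcessor2_0

dim, num_attention_heads, attention_head_dim = 512, 8, 64
is_mmdit_block = True  # assumption: an MMDiT-style block with a text stream

attn1 = Attention(
    query_dim=dim,
    dim_head=attention_head_dim,
    heads=num_attention_heads,
    qk_norm="layer_norm",
    eps=1e-6,
    bias=True,
    added_proj_bias=True,
    added_kv_proj_dim=dim if is_mmdit_block else None,
    context_pre_only=False if is_mmdit_block else None,
    processor=EasyAnimateAttnProcessor2_0(),
)
# With added_kv_proj_dim set, Attention creates add_q_proj/add_k_proj/
# add_v_proj and norm_added_q/norm_added_k, and (with context_pre_only=False)
# to_add_out -- exactly the attributes the rewritten processor reads off attn.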

Comments (0)