
Commit 461ab73

Corrected einops removal
1 parent 4383175 commit 461ab73

File tree

2 files changed: +7 -7 lines changed


src/diffusers/models/attention_processor.py

Lines changed: 6 additions & 6 deletions
@@ -5155,11 +5155,11 @@ def __call__(
         ip_value = self.to_v_ip(norm_ip_hidden_states)

         # Reshape
-        img_query = img_query.view(batch_size, head_dim, attn.heads, -1).transpose(1,2)
-        img_key = img_key.view(batch_size, head_dim, attn.heads, -1).transpose(1,2)
-        img_value = img_value.view(batch_size, head_dim, attn.heads, -1).transpose(1,2)
-        ip_key = ip_key.view(batch_size, head_dim, attn.heads, -1).transpose(1,2)
-        ip_value = ip_value.view(batch_size, head_dim, attn.heads, -1).transpose(1,2)
+        img_query = img_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        img_key = img_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        img_value = img_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

         # Norm
         img_query = self.norm_q(img_query)
@@ -5171,7 +5171,7 @@ def __call__(
         img_value = torch.cat([img_value, ip_value], dim=2)

         ip_hidden_states = F.scaled_dot_product_attention(img_query, img_key, img_value, dropout_p=0.0, is_causal=False)
-        ip_hidden_states = ip_hidden_states.transpose(1,2).view(batch_size, head_dim, -1)
+        ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim)
         ip_hidden_states = ip_hidden_states.to(img_query.dtype)

         hidden_states = hidden_states + ip_hidden_states * self.scale
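
As context for the fix above: the corrected "+" lines use the standard multi-head split/merge reshape that the removed einops.rearrange calls expressed. Below is a minimal, self-contained sketch of that equivalence; the concrete sizes (batch_size=2, seq_len=16, heads=8, head_dim=64) are illustrative assumptions, not values taken from the diff.

    # Sketch (assumed sizes) of the head-split / head-merge reshapes that
    # replace the removed einops.rearrange calls.
    import torch

    batch_size, seq_len, heads, head_dim = 2, 16, 8, 64
    x = torch.randn(batch_size, seq_len, heads * head_dim)

    # Split heads: (b, s, h*d) -> (b, h, s, d), matching the "+" lines above;
    # equivalent to einops.rearrange(x, "b s (h d) -> b h s d", h=heads).
    q = x.view(batch_size, -1, heads, head_dim).transpose(1, 2)
    assert q.shape == (batch_size, heads, seq_len, head_dim)

    # The removed "-" variant, view(batch_size, head_dim, heads, -1), only lines up
    # by accident when seq_len == head_dim and otherwise scrambles the layout.

    # Merge heads after attention: (b, h, s, d) -> (b, s, h*d); reshape is used here
    # because the transposed tensor is no longer contiguous.
    out = q.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
    assert out.shape == (batch_size, seq_len, heads * head_dim)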

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 1 addition & 1 deletion
@@ -994,8 +994,8 @@ def __call__(
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
                     pooled_projections=pooled_prompt_embeds,
-                    return_dict=False,
                     joint_attention_kwargs=self.joint_attention_kwargs,
+                    return_dict=False,
                 )[0]

                 # perform guidance
