Skip to content

Commit 5aed1d3

Browse files
committed
Removed usage of einops
1 parent d868ddb commit 5aed1d3

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

src/diffusers/models/attention_processor.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import torch
1919
import torch.nn.functional as F
2020
from torch import nn
21-
from einops import rearrange
2221

2322
from ..image_processor import IPAdapterMaskProcessor
2423
from ..utils import deprecate, is_torch_xla_available, logging
@@ -5156,11 +5155,11 @@ def __call__(
51565155
ip_value = self.to_v_ip(norm_ip_hidden_states)
51575156

51585157
# Reshape
5159-
img_query = rearrange(img_query, 'b l (h d) -> b h l d', h=attn.heads)
5160-
img_key = rearrange(img_key, 'b l (h d) -> b h l d', h=attn.heads)
5161-
img_value = rearrange(img_value, 'b l (h d) -> b h l d', h=attn.heads)
5162-
ip_key = rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads)
5163-
ip_value = rearrange(ip_value, 'b l (h d) -> b h l d', h=attn.heads)
5158+
img_query = img_query.view(batch_size, head_dim, attn.heads, -1).transpose(1,2)
5159+
img_key = img_key.view(batch_size, head_dim, attn.heads, -1).transpose(1,2)
5160+
img_value = img_value.view(batch_size, head_dim, attn.heads, -1).transpose(1,2)
5161+
ip_key = ip_key.view(batch_size, head_dim, attn.heads, -1).transpose(1,2)
5162+
ip_value = ip_value.view(batch_size, head_dim, attn.heads, -1).transpose(1,2)
51645163

51655164
# Norm
51665165
img_query = self.norm_q(img_query)
@@ -5172,7 +5171,7 @@ def __call__(
51725171
img_value = torch.cat([img_value, ip_value], dim=2)
51735172

51745173
ip_hidden_states = F.scaled_dot_product_attention(img_query, img_key, img_value, dropout_p=0.0, is_causal=False)
5175-
ip_hidden_states = rearrange(ip_hidden_states, 'b h l d -> b l (h d)')
5174+
ip_hidden_states = ip_hidden_states.transpose(1,2).view(batch_size, head_dim, -1)
51765175
ip_hidden_states = ip_hidden_states.to(img_query.dtype)
51775176

51785177
hidden_states = hidden_states + ip_hidden_states * self.scale

0 commit comments

Comments
 (0)