import torch
import torch.nn.functional as F
from torch import nn
-from einops import rearrange

from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, is_torch_xla_available, logging
@@ -5156,11 +5155,11 @@ def __call__(
        ip_value = self.to_v_ip(norm_ip_hidden_states)

        # Reshape
-        img_query = rearrange(img_query, 'b l (h d) -> b h l d', h=attn.heads)
-        img_key = rearrange(img_key, 'b l (h d) -> b h l d', h=attn.heads)
-        img_value = rearrange(img_value, 'b l (h d) -> b h l d', h=attn.heads)
-        ip_key = rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads)
-        ip_value = rearrange(ip_value, 'b l (h d) -> b h l d', h=attn.heads)
+        img_query = img_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        img_key = img_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        img_value = img_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Norm
        img_query = self.norm_q(img_query)
@@ -5172,7 +5171,7 @@ def __call__(
        img_value = torch.cat([img_value, ip_value], dim=2)

        ip_hidden_states = F.scaled_dot_product_attention(img_query, img_key, img_value, dropout_p=0.0, is_causal=False)
-        ip_hidden_states = rearrange(ip_hidden_states, 'b h l d -> b l (h d)')
+        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(img_query.dtype)

        hidden_states = hidden_states + ip_hidden_states * self.scale
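A quick sanity check of the replacement (not part of the diff): the minimal sketch below uses made-up tensor sizes to confirm that `view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)` reproduces `rearrange(x, 'b l (h d) -> b h l d', h=attn.heads)`, and that the inverse needs `reshape` (or `contiguous().view`) after the transpose, since the transposed tensor is no longer contiguous and a plain `view` would raise.

```python
# Standalone check of the einops-to-PyTorch rewrite; dimension values are illustrative.
import torch
from einops import rearrange  # only needed for this one-off comparison

batch_size, seq_len, heads, head_dim = 2, 7, 4, 8
x = torch.randn(batch_size, seq_len, heads * head_dim)

# 'b l (h d) -> b h l d'  ==  split the fused head dim, then move heads forward
a = rearrange(x, 'b l (h d) -> b h l d', h=heads)
b = x.view(batch_size, -1, heads, head_dim).transpose(1, 2)
assert torch.equal(a, b)

# 'b h l d -> b l (h d)'  ==  move heads back, then merge (heads, head_dim);
# reshape handles the non-contiguous layout produced by transpose
y = torch.randn(batch_size, heads, seq_len, head_dim)
c = rearrange(y, 'b h l d -> b l (h d)')
d = y.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
assert torch.equal(c, d)
```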