@@ -5155,11 +5155,11 @@ def __call__(
51555155            ip_value  =  self .to_v_ip (norm_ip_hidden_states )
51565156
51575157            # Reshape 
5158-             img_query  =  img_query .view (batch_size , head_dim , attn .heads , - 1 ).transpose (1 ,2 )             
5159-             img_key  =  img_key .view (batch_size , head_dim , attn .heads , - 1 ).transpose (1 ,2 )
5160-             img_value  =  img_value .view (batch_size , head_dim , attn .heads , - 1 ).transpose (1 ,2 )
5161-             ip_key  =  ip_key .view (batch_size , head_dim , attn .heads , - 1 ).transpose (1 ,2 )
5162-             ip_value  =  ip_value .view (batch_size , head_dim , attn .heads , - 1 ).transpose (1 ,2 )
5158+             img_query  =  img_query .view (batch_size , - 1 , attn .heads , head_dim ).transpose (1 ,  2 ) 
5159+             img_key  =  img_key .view (batch_size , - 1 , attn .heads , head_dim ).transpose (1 ,  2 )
5160+             img_value  =  img_value .view (batch_size , - 1 , attn .heads , head_dim ).transpose (1 ,  2 )
5161+             ip_key  =  ip_key .view (batch_size , - 1 , attn .heads , head_dim ).transpose (1 ,  2 )
5162+             ip_value  =  ip_value .view (batch_size , - 1 , attn .heads , head_dim ).transpose (1 ,  2 )
51635163
51645164            # Norm 
51655165            img_query  =  self .norm_q (img_query )
@@ -5171,7 +5171,7 @@ def __call__(
51715171            img_value  =  torch .cat ([img_value , ip_value ], dim = 2 )
51725172
51735173            ip_hidden_states  =  F .scaled_dot_product_attention (img_query , img_key , img_value , dropout_p = 0.0 , is_causal = False )
5174-             ip_hidden_states  =  ip_hidden_states .transpose (1 ,2 ).view (batch_size , head_dim ,  - 1 )
5174+             ip_hidden_states  =  ip_hidden_states .transpose (1 ,  2 ).view (batch_size , - 1 ,  attn . heads   *   head_dim )
51755175            ip_hidden_states  =  ip_hidden_states .to (img_query .dtype )
51765176
51775177            hidden_states  =  hidden_states  +  ip_hidden_states  *  self .scale 
0 commit comments