@@ -93,9 +93,9 @@ def __call__(
9393 query = attn .norm_q (query )
9494 key = attn .norm_k (key )
9595
96- query = query .unflatten (2 , (attn .heads , - 1 )). transpose ( 1 , 2 )
97- key = key .unflatten (2 , (attn .heads , - 1 )). transpose ( 1 , 2 )
98- value = value .unflatten (2 , (attn .heads , - 1 )). transpose ( 1 , 2 )
96+ query = query .unflatten (2 , (attn .heads , - 1 ))
97+ key = key .unflatten (2 , (attn .heads , - 1 ))
98+ value = value .unflatten (2 , (attn .heads , - 1 ))
9999
100100 if rotary_emb is not None :
101101
@@ -104,8 +104,7 @@ def apply_rotary_emb(
104104 freqs_cos : torch .Tensor ,
105105 freqs_sin : torch .Tensor ,
106106 ):
107- x = hidden_states .view (* hidden_states .shape [:- 1 ], - 1 , 2 )
108- x1 , x2 = x [..., 0 ], x [..., 1 ]
107+ x1 , x2 = hidden_states .unflatten (- 1 , (- 1 , 2 )).unbind (- 1 )
109108 cos = freqs_cos [..., 0 ::2 ]
110109 sin = freqs_sin [..., 1 ::2 ]
111110 out = torch .empty_like (hidden_states )
@@ -122,8 +121,8 @@ def apply_rotary_emb(
122121 key_img , value_img = _get_added_kv_projections (attn , encoder_hidden_states_img )
123122 key_img = attn .norm_added_k (key_img )
124123
125- key_img = key_img .unflatten (2 , (attn .heads , - 1 )). transpose ( 1 , 2 )
126- value_img = value_img .unflatten (2 , (attn .heads , - 1 )). transpose ( 1 , 2 )
124+ key_img = key_img .unflatten (2 , (attn .heads , - 1 ))
125+ value_img = value_img .unflatten (2 , (attn .heads , - 1 ))
127126
128127 hidden_states_img = dispatch_attention_fn (
129128 query ,
@@ -134,7 +133,7 @@ def apply_rotary_emb(
134133 is_causal = False ,
135134 backend = self ._attention_backend ,
136135 )
137- hidden_states_img = hidden_states_img .transpose ( 1 , 2 ). flatten (2 , 3 )
136+ hidden_states_img = hidden_states_img .flatten (2 , 3 )
138137 hidden_states_img = hidden_states_img .type_as (query )
139138
140139 hidden_states = dispatch_attention_fn (
@@ -146,7 +145,7 @@ def apply_rotary_emb(
146145 is_causal = False ,
147146 backend = self ._attention_backend ,
148147 )
149- hidden_states = hidden_states .transpose ( 1 , 2 ). flatten (2 , 3 )
148+ hidden_states = hidden_states .flatten (2 , 3 )
150149 hidden_states = hidden_states .type_as (query )
151150
152151 if hidden_states_img is not None :
@@ -395,8 +394,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
395394 freqs_sin_h = freqs_sin [1 ][:pph ].view (1 , pph , 1 , - 1 ).expand (ppf , pph , ppw , - 1 )
396395 freqs_sin_w = freqs_sin [2 ][:ppw ].view (1 , 1 , ppw , - 1 ).expand (ppf , pph , ppw , - 1 )
397396
398- freqs_cos = torch .cat ([freqs_cos_f , freqs_cos_h , freqs_cos_w ], dim = - 1 ).reshape (1 , 1 , ppf * pph * ppw , - 1 )
399- freqs_sin = torch .cat ([freqs_sin_f , freqs_sin_h , freqs_sin_w ], dim = - 1 ).reshape (1 , 1 , ppf * pph * ppw , - 1 )
397+ freqs_cos = torch .cat ([freqs_cos_f , freqs_cos_h , freqs_cos_w ], dim = - 1 ).reshape (1 , ppf * pph * ppw , 1 , - 1 )
398+ freqs_sin = torch .cat ([freqs_sin_f , freqs_sin_h , freqs_sin_w ], dim = - 1 ).reshape (1 , ppf * pph * ppw , 1 , - 1 )
400399
401400 return freqs_cos , freqs_sin
402401
0 commit comments