@@ -108,13 +108,21 @@ def __call__(
108108
109109 if rotary_emb is not None :
110110
111- def apply_rotary_emb (hidden_states : torch .Tensor , freqs : torch .Tensor ):
112- x_rotated = torch .view_as_complex (hidden_states .to (torch .float32 ).unflatten (3 , (- 1 , 2 )))
113- x_out = torch .view_as_real (x_rotated * freqs ).flatten (3 , 4 )
114- return x_out .type_as (hidden_states )
115-
116- query = apply_rotary_emb (query , rotary_emb )
117- key = apply_rotary_emb (key , rotary_emb )
def apply_rotary_emb(
    hidden_states: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
):
    """Rotate interleaved (even, odd) channel pairs of ``hidden_states``.

    The last dimension of ``hidden_states`` is read as consecutive pairs
    ``(x1, x2)``; each pair is rotated by the angle whose cosine/sine is
    taken from the matching position of ``freqs_cos`` / ``freqs_sin``.

    Args:
        hidden_states: Tensor whose last dimension has even size.
        freqs_cos: Cosine table. Only its even-indexed entries along the
            last dimension are read; must broadcast against the halved
            ``hidden_states``.
        freqs_sin: Sine table. Only its odd-indexed entries along the
            last dimension are read; same broadcasting requirement.

    Returns:
        A new tensor with the shape and dtype of ``hidden_states``.
    """
    # Do the arithmetic in at least float32. Without this, half-precision
    # activations combined with half-precision tables would round every
    # product to half precision before the add/subtract. Higher-precision
    # tables (e.g. float64) still promote the computation further.
    compute_dtype = torch.promote_types(hidden_states.dtype, torch.float32)

    x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
    x1 = x1.to(compute_dtype)
    x2 = x2.to(compute_dtype)

    cos = freqs_cos[..., 0::2]
    sin = freqs_sin[..., 1::2]

    # `out` already carries hidden_states' dtype; the slice assignments cast
    # the higher-precision results down exactly once.
    out = torch.empty_like(hidden_states)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
123+
124+ query = apply_rotary_emb (query , * rotary_emb )
125+ key = apply_rotary_emb (key , * rotary_emb )
118126
119127 # I2V task
120128 hidden_states_img = None
@@ -358,7 +366,11 @@ def forward(
358366
359367class SkyReelsV2RotaryPosEmbed (nn .Module ):
360368 def __init__ (
361- self , attention_head_dim : int , patch_size : Tuple [int , int , int ], max_seq_len : int , theta : float = 10000.0
369+ self ,
370+ attention_head_dim : int ,
371+ patch_size : Tuple [int , int , int ],
372+ max_seq_len : int ,
373+ theta : float = 10000.0 ,
362374 ):
363375 super ().__init__ ()
364376
@@ -368,35 +380,52 @@ def __init__(
368380
369381 h_dim = w_dim = 2 * (attention_head_dim // 6 )
370382 t_dim = attention_head_dim - h_dim - w_dim
383+ freqs_dtype = torch .float32 if torch .backends .mps .is_available () else torch .float64
384+
385+ freqs_cos = []
386+ freqs_sin = []
371387
372- freqs = []
373388 for dim in [t_dim , h_dim , w_dim ]:
374- freq = get_1d_rotary_pos_embed (
375- dim , max_seq_len , theta , use_real = False , repeat_interleave_real = False , freqs_dtype = torch .float32
389+ freq_cos , freq_sin = get_1d_rotary_pos_embed (
390+ dim ,
391+ max_seq_len ,
392+ theta ,
393+ use_real = True ,
394+ repeat_interleave_real = True ,
395+ freqs_dtype = freqs_dtype ,
376396 )
377- freqs .append (freq )
378- self .freqs = torch .cat (freqs , dim = 1 )
397+ freqs_cos .append (freq_cos )
398+ freqs_sin .append (freq_sin )
399+
400+ self .register_buffer ("freqs_cos" , torch .cat (freqs_cos , dim = 1 ), persistent = False )
401+ self .register_buffer ("freqs_sin" , torch .cat (freqs_sin , dim = 1 ), persistent = False )
379402
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build per-token rotary cos/sin tables for a 5-D input.

    Args:
        hidden_states: Tensor of shape
            ``(batch, channels, num_frames, height, width)``. Only the
            last three sizes are used.

    Returns:
        ``(freqs_cos, freqs_sin)``, each of shape
        ``(1, ppf * pph * ppw, 1, D)`` where ``ppf/pph/ppw`` are the
        frame/height/width sizes divided by ``self.patch_size`` and ``D``
        is the column count of the precomputed tables.
    """
    _, _, num_frames, height, width = hidden_states.shape
    p_t, p_h, p_w = self.patch_size
    ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

    # Use the same per-axis widths as __init__ used when concatenating the
    # tables. `attention_head_dim // 3` only equals `2 * (attention_head_dim // 6)`
    # when attention_head_dim % 6 < 3; otherwise (e.g. 64) the split points
    # would be shifted and columns attributed to the wrong axis.
    h_dim = w_dim = 2 * (self.attention_head_dim // 6)
    t_dim = self.attention_head_dim - h_dim - w_dim
    split_sizes = [t_dim, h_dim, w_dim]

    def build_grid(table: torch.Tensor) -> torch.Tensor:
        # Slice each axis table to the needed length, broadcast it over
        # the other two axes, then flatten the (f, h, w) grid into tokens.
        per_t, per_h, per_w = table.split(split_sizes, dim=1)
        grid = torch.cat(
            [
                per_t[:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1),
                per_h[:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1),
                per_w[:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1),
            ],
            dim=-1,
        )
        return grid.reshape(1, ppf * pph * ppw, 1, -1)

    return build_grid(self.freqs_cos), build_grid(self.freqs_sin)
400429
401430
402431@maybe_allow_in_graph
0 commit comments