@@ -680,6 +680,31 @@ def dtype(self) -> torch.dtype:
680680 def device (self ) -> torch .device :
681681 return self .patch_embed .proj .weight .device
682682
683+ def rot_pos_emb (self , grid_thw : torch .Tensor ) -> torch .Tensor :
684+ pos_ids = []
685+ for t , h , w in grid_thw :
686+ hpos_ids = torch .arange (h ).unsqueeze (1 ).expand (- 1 , w )
687+ wpos_ids = torch .arange (w ).unsqueeze (0 ).expand (h , - 1 )
688+ hpos_ids = hpos_ids .reshape (
689+ h // self .spatial_merge_size ,
690+ self .spatial_merge_size ,
691+ w // self .spatial_merge_size ,
692+ self .spatial_merge_size ,
693+ ).permute (0 , 2 , 1 , 3 ).flatten ()
694+ wpos_ids = wpos_ids .reshape (
695+ h // self .spatial_merge_size ,
696+ self .spatial_merge_size ,
697+ w // self .spatial_merge_size ,
698+ self .spatial_merge_size ,
699+ ).permute (0 , 2 , 1 , 3 ).flatten ()
700+ pos_ids .append (
701+ torch .stack ([hpos_ids , wpos_ids ], dim = - 1 ).repeat (t , 1 ))
702+ pos_ids = torch .cat (pos_ids , dim = 0 )
703+ max_grid_size = grid_thw [:, 1 :].max ()
704+ rotary_pos_emb_full = self .rotary_pos_emb (max_grid_size )
705+ rotary_pos_emb = rotary_pos_emb_full [pos_ids ].flatten (1 )
706+ return rotary_pos_emb
707+
683708 def rotary_pos_emb_thw (self , t , h , w ):
684709 hpos_ids = torch .arange (h ).unsqueeze (1 ).expand (- 1 , w )
685710 wpos_ids = torch .arange (w ).unsqueeze (0 ).expand (h , - 1 )
0 commit comments