@@ -594,6 +594,7 @@ def get_3d_rotary_pos_embed(
     use_real: bool = True,
     grid_type: str = "linspace",
     max_size: Optional[Tuple[int, int]] = None,
+    device: Optional[torch.device] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     RoPE for video tokens with 3D structure.
@@ -621,16 +622,22 @@ def get_3d_rotary_pos_embed(
     if grid_type == "linspace":
         start, stop = crops_coords
         grid_size_h, grid_size_w = grid_size
-        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
-        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
-        grid_t = np.arange(temporal_size, dtype=np.float32)
-        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+        grid_h = torch.linspace(
+            start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+        )
+        grid_w = torch.linspace(
+            start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+        )
+        grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
+        grid_t = torch.linspace(
+            0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+        )
     elif grid_type == "slice":
         max_h, max_w = max_size
         grid_size_h, grid_size_w = grid_size
-        grid_h = np.arange(max_h, dtype=np.float32)
-        grid_w = np.arange(max_w, dtype=np.float32)
-        grid_t = np.arange(temporal_size, dtype=np.float32)
+        grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
+        grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
+        grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
     else:
         raise ValueError("Invalid value passed for `grid_type`.")

@@ -640,10 +647,10 @@ def get_3d_rotary_pos_embed(
     dim_w = embed_dim // 8 * 3

     # Temporal frequencies
-    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
+    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
     # Spatial frequencies for height and width
-    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
-    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
+    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)

     # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
     def combine_time_height_width(freqs_t, freqs_h, freqs_w):
@@ -686,14 +693,21 @@ def get_3d_rotary_pos_embed_allegro(
     temporal_size,
     interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
     theta: int = 10000,
+    device: Optional[torch.device] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     # TODO(aryan): docs
     start, stop = crops_coords
     grid_size_h, grid_size_w = grid_size
     interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
-    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
-    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
-    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
+    grid_t = torch.linspace(
+        0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+    )
+    grid_h = torch.linspace(
+        start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+    )
+    grid_w = torch.linspace(
+        start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+    )

     # Compute dimensions for each axis
     dim_t = embed_dim // 3
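
For reference, a minimal sketch (not part of the patch) of why the new `torch.linspace` stop value matches the old `np.linspace(..., endpoint=False)` grid when the range starts at 0: both produce the points `i * stop / n` for `i = 0 .. n - 1`, since `torch.linspace` always includes its endpoint and the patch pre-scales the stop by `(n - 1) / n`. The concrete `stop` and `n` values below are illustrative assumptions, not taken from the commit.

import numpy as np
import torch

# Illustrative values; any positive stop and n >= 2 behave the same way.
stop, n = 30.0, 45

# Old grid: n evenly spaced points in [0, stop), endpoint excluded.
old = np.linspace(0, stop, n, endpoint=False, dtype=np.float32)

# New grid: the stop is pre-scaled by (n - 1) / n so the last included
# point is (n - 1) * stop / n, the same last point as the old grid.
new = torch.linspace(0, stop * (n - 1) / n, n, dtype=torch.float32)

assert np.allclose(old, new.numpy())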