@@ -594,6 +594,7 @@ def get_3d_rotary_pos_embed(
594594 use_real : bool = True ,
595595 grid_type : str = "linspace" ,
596596 max_size : Optional [Tuple [int , int ]] = None ,
597+ device : Optional [torch .device ] = None ,
597598) -> Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]]:
598599 """
599600 RoPE for video tokens with 3D structure.
@@ -621,16 +622,22 @@ def get_3d_rotary_pos_embed(
621622 if grid_type == "linspace" :
622623 start , stop = crops_coords
623624 grid_size_h , grid_size_w = grid_size
624- grid_h = np .linspace (start [0 ], stop [0 ], grid_size_h , endpoint = False , dtype = np .float32 )
625- grid_w = np .linspace (start [1 ], stop [1 ], grid_size_w , endpoint = False , dtype = np .float32 )
626- grid_t = np .arange (temporal_size , dtype = np .float32 )
627- grid_t = np .linspace (0 , temporal_size , temporal_size , endpoint = False , dtype = np .float32 )
625+ grid_h = torch .linspace (
626+ start [0 ], stop [0 ] * (grid_size_h - 1 ) / grid_size_h , grid_size_h , device = device , dtype = torch .float32
627+ )
628+ grid_w = torch .linspace (
629+ start [1 ], stop [1 ] * (grid_size_w - 1 ) / grid_size_w , grid_size_w , device = device , dtype = torch .float32
630+ )
631+ grid_t = torch .arange (temporal_size , device = device , dtype = torch .float32 )
632+ grid_t = torch .linspace (
633+ 0 , temporal_size * (temporal_size - 1 ) / temporal_size , temporal_size , device = device , dtype = torch .float32
634+ )
628635 elif grid_type == "slice" :
629636 max_h , max_w = max_size
630637 grid_size_h , grid_size_w = grid_size
631- grid_h = np .arange (max_h , dtype = np .float32 )
632- grid_w = np .arange (max_w , dtype = np .float32 )
633- grid_t = np .arange (temporal_size , dtype = np .float32 )
638+ grid_h = torch .arange (max_h , device = device , dtype = torch .float32 )
639+ grid_w = torch .arange (max_w , device = device , dtype = torch .float32 )
640+ grid_t = torch .arange (temporal_size , device = device , dtype = torch .float32 )
634641 else :
635642 raise ValueError ("Invalid value passed for `grid_type`." )
636643
@@ -640,10 +647,10 @@ def get_3d_rotary_pos_embed(
640647 dim_w = embed_dim // 8 * 3
641648
642649 # Temporal frequencies
643- freqs_t = get_1d_rotary_pos_embed (dim_t , grid_t , use_real = True )
650+ freqs_t = get_1d_rotary_pos_embed (dim_t , grid_t , theta = theta , use_real = True )
644651 # Spatial frequencies for height and width
645- freqs_h = get_1d_rotary_pos_embed (dim_h , grid_h , use_real = True )
646- freqs_w = get_1d_rotary_pos_embed (dim_w , grid_w , use_real = True )
652+ freqs_h = get_1d_rotary_pos_embed (dim_h , grid_h , theta = theta , use_real = True )
653+ freqs_w = get_1d_rotary_pos_embed (dim_w , grid_w , theta = theta , use_real = True )
647654
648655 # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
649656 def combine_time_height_width (freqs_t , freqs_h , freqs_w ):
@@ -686,14 +693,21 @@ def get_3d_rotary_pos_embed_allegro(
686693 temporal_size ,
687694 interpolation_scale : Tuple [float , float , float ] = (1.0 , 1.0 , 1.0 ),
688695 theta : int = 10000 ,
696+ device : Optional [torch .device ] = None ,
689697) -> Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]]:
690698 # TODO(aryan): docs
691699 start , stop = crops_coords
692700 grid_size_h , grid_size_w = grid_size
693701 interpolation_scale_t , interpolation_scale_h , interpolation_scale_w = interpolation_scale
694- grid_t = np .linspace (0 , temporal_size , temporal_size , endpoint = False , dtype = np .float32 )
695- grid_h = np .linspace (start [0 ], stop [0 ], grid_size_h , endpoint = False , dtype = np .float32 )
696- grid_w = np .linspace (start [1 ], stop [1 ], grid_size_w , endpoint = False , dtype = np .float32 )
702+ grid_t = torch .linspace (
703+ 0 , temporal_size * (temporal_size - 1 ) / temporal_size , temporal_size , device = device , dtype = torch .float32
704+ )
705+ grid_h = torch .linspace (
706+ start [0 ], stop [0 ] * (grid_size_h - 1 ) / grid_size_h , grid_size_h , device = device , dtype = torch .float32
707+ )
708+ grid_w = torch .linspace (
709+ start [1 ], stop [1 ] * (grid_size_w - 1 ) / grid_size_w , grid_size_w , device = device , dtype = torch .float32
710+ )
697711
698712 # Compute dimensions for each axis
699713 dim_t = embed_dim // 3
0 commit comments