@@ -957,7 +957,57 @@ def get_3d_rotary_pos_embed_allegro(
957957    return  freqs_t , freqs_h , freqs_w , grid_t , grid_h , grid_w 
958958
959959
960- def  get_2d_rotary_pos_embed (embed_dim , crops_coords , grid_size , use_real = True ):
960+ def  get_2d_rotary_pos_embed (
961+     embed_dim , crops_coords , grid_size , use_real = True , device : Optional [torch .device ] =  None , output_type : str  =  "np" 
962+ ):
963+     """ 
964+     RoPE for image tokens with 2d structure. 
965+ 
966+     Args: 
967+     embed_dim: (`int`): 
968+         The embedding dimension size 
969+     crops_coords (`Tuple[int]`) 
970+         The top-left and bottom-right coordinates of the crop. 
971+     grid_size (`Tuple[int]`): 
972+         The grid size of the positional embedding. 
973+     use_real (`bool`): 
974+         If True, return real part and imaginary part separately. Otherwise, return complex numbers. 
975+     device: (`torch.device`, **optional**): 
976+         The device used to create tensors. 
977+ 
978+     Returns: 
979+         `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`. 
980+     """ 
981+     if  output_type  ==  "np" :
982+         deprecation_message  =  (
983+             "`get_2d_sincos_pos_embed` uses `torch` and supports `device`." 
984+             " `from_numpy` is no longer required." 
985+             "  Pass `output_type='pt' to use the new version now." 
986+         )
987+         deprecate ("output_type=='np'" , "0.33.0" , deprecation_message , standard_warn = False )
988+         return  _get_2d_rotary_pos_embed_np (
989+             embed_dim = embed_dim ,
990+             crops_coords = crops_coords ,
991+             grid_size = grid_size ,
992+             use_real = use_real ,
993+         )
994+     start , stop  =  crops_coords 
995+     # scale end by (steps−1)/steps matches np.linspace(..., endpoint=False) 
996+     grid_h  =  torch .linspace (
997+         start [0 ], stop [0 ] *  (grid_size [0 ] -  1 ) /  grid_size [0 ], grid_size [0 ], device = device , dtype = torch .float32 
998+     )
999+     grid_w  =  torch .linspace (
1000+         start [1 ], stop [1 ] *  (grid_size [1 ] -  1 ) /  grid_size [1 ], grid_size [1 ], device = device , dtype = torch .float32 
1001+     )
1002+     grid  =  torch .meshgrid (grid_w , grid_h , indexing = "xy" )
1003+     grid  =  torch .stack (grid , dim = 0 )  # [2, W, H] 
1004+ 
1005+     grid  =  grid .reshape ([2 , 1 , * grid .shape [1 :]])
1006+     pos_embed  =  get_2d_rotary_pos_embed_from_grid (embed_dim , grid , use_real = use_real )
1007+     return  pos_embed 
1008+ 
1009+ 
1010+ def  _get_2d_rotary_pos_embed_np (embed_dim , crops_coords , grid_size , use_real = True ):
9611011    """ 
9621012    RoPE for image tokens with 2d structure. 
9631013
0 commit comments