@@ -957,7 +957,9 @@ def get_3d_rotary_pos_embed_allegro(
957957 return freqs_t , freqs_h , freqs_w , grid_t , grid_h , grid_w
958958
959959
960- def get_2d_rotary_pos_embed (embed_dim , crops_coords , grid_size , use_real = True , device : Optional [torch .device ] = None ):
960+ def get_2d_rotary_pos_embed (
961+ embed_dim , crops_coords , grid_size , use_real = True , device : Optional [torch .device ] = None , output_type : str = "np"
962+ ):
961963 """
962964 RoPE for image tokens with 2d structure.
963965
@@ -976,6 +978,19 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True, d
976978 Returns:
977979 `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
978980 """
981+ if output_type == "np" :
982+ deprecation_message = (
983+ "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
984+ " `from_numpy` is no longer required."
985+ " Pass `output_type='pt' to use the new version now."
986+ )
987+ deprecate ("output_type=='np'" , "0.33.0" , deprecation_message , standard_warn = False )
988+ return _get_2d_rotary_pos_embed_np (
989+ embed_dim = embed_dim ,
990+ crops_coords = crops_coords ,
991+ grid_size = grid_size ,
992+ use_real = use_real ,
993+ )
979994 start , stop = crops_coords
980995 # scale end by (steps−1)/steps matches np.linspace(..., endpoint=False)
981996 grid_h = torch .linspace (
@@ -992,6 +1007,34 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True, d
9921007 return pos_embed
9931008
9941009
def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
    """
    NumPy-based RoPE for image tokens with a 2D structure (legacy `output_type="np"` path).

    Args:
        embed_dim: (`int`):
            The embedding dimension size
        crops_coords (`Tuple[int]`)
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int]`):
            The grid size of the positional embedding.
        use_real (`bool`):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.

    Returns:
        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
    """
    top_left, bottom_right = crops_coords
    # Sample coordinates over the crop, excluding the endpoint.
    ys = np.linspace(top_left[0], bottom_right[0], grid_size[0], endpoint=False, dtype=np.float32)
    xs = np.linspace(top_left[1], bottom_right[1], grid_size[1], endpoint=False, dtype=np.float32)
    # Default 'xy' meshgrid indexing: w varies along the last axis,
    # so the stacked result has shape (2, grid_size[0], grid_size[1]).
    mesh = np.stack(np.meshgrid(xs, ys), axis=0)
    # Insert a singleton leading axis per coordinate, as expected by the grid helper.
    mesh = mesh.reshape(2, 1, *mesh.shape[1:])
    return get_2d_rotary_pos_embed_from_grid(embed_dim, mesh, use_real=use_real)
1036+
1037+
9951038def get_2d_rotary_pos_embed_from_grid (embed_dim , grid , use_real = False ):
9961039 """
9971040 Get 2D RoPE from grid.
0 commit comments