@@ -957,7 +957,57 @@ def get_3d_rotary_pos_embed_allegro(
     return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w


-def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
+def get_2d_rotary_pos_embed(
+    embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
+):
+    """
+    RoPE for image tokens with 2d structure.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension size.
+        crops_coords (`Tuple[int]`):
+            The top-left and bottom-right coordinates of the crop.
+        grid_size (`Tuple[int]`):
+            The grid size of the positional embedding.
+        use_real (`bool`):
+            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+        device (`torch.device`, *optional*):
+            The device used to create tensors.
+        output_type (`str`, *optional*, defaults to `"np"`):
+            The backend used to build the embedding; pass `"pt"` for the torch path, as `"np"` is deprecated.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `(grid_size * grid_size, embed_dim/2)`.
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_rotary_pos_embed` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt'` to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return _get_2d_rotary_pos_embed_np(
+            embed_dim=embed_dim,
+            crops_coords=crops_coords,
+            grid_size=grid_size,
+            use_real=use_real,
+        )
+    start, stop = crops_coords
+    # scale the end by (steps - 1) / steps to match np.linspace(..., endpoint=False)
+    grid_h = torch.linspace(
+        start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32
+    )
+    grid_w = torch.linspace(
+        start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32
+    )
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+    grid = torch.stack(grid, dim=0)  # [2, H, W]
+
+    grid = grid.reshape([2, 1, *grid.shape[1:]])
+    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+    return pos_embed
+
+
+def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
     """
     RoPE for image tokens with 2d structure.

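For anyone who wants to try the new torch code path, here is a minimal usage sketch (not part of the commit). The embedding size, crop coordinates, and grid size are illustrative values, and the import assumes a diffusers build that already contains this change:

import torch
from diffusers.models.embeddings import get_2d_rotary_pos_embed

# Illustrative setup: a 64-dim rotary embedding over an uncropped 32x32 token grid.
cos, sin = get_2d_rotary_pos_embed(
    embed_dim=64,
    crops_coords=((0, 0), (32, 32)),  # top-left and bottom-right of the crop region
    grid_size=(32, 32),
    use_real=True,
    device=torch.device("cpu"),
    output_type="pt",  # the "np" default still works but now emits a deprecation warning
)
print(cos.shape, sin.shape)  # one (cos, sin) row per image token, 32 * 32 = 1024 rows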