@@ -733,10 +733,11 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
733733         `torch.Tensor`: positional embedding with shape `(grid_size * grid_size, embed_dim/2)`.
734734 """
735735 start , stop = crops_coords
736- grid_h = np .linspace (start [0 ], stop [0 ], grid_size [0 ], endpoint = False , dtype = np .float32 )
737- grid_w = np .linspace (start [1 ], stop [1 ], grid_size [1 ], endpoint = False , dtype = np .float32 )
738- grid = np .meshgrid (grid_w , grid_h ) # here w goes first
739- grid = np .stack (grid , axis = 0 ) # [2, W, H]
736+     # scale end by (steps-1)/steps to match np.linspace(..., endpoint=False); NOTE(review): exact only when start == 0 — in general the endpoint-excluded last sample is start + (stop - start) * (steps - 1) / steps
737+ grid_h = torch .linspace (start [0 ], stop [0 ] * (grid_size [0 ] - 1 ) / grid_size [0 ], grid_size [0 ], dtype = torch .float32 )
738+ grid_w = torch .linspace (start [1 ], stop [1 ] * (grid_size [1 ] - 1 ) / grid_size [1 ], grid_size [1 ], dtype = torch .float32 )
739+ grid = torch .meshgrid (grid_w , grid_h , indexing = "xy" )
740+ grid = torch .stack (grid , dim = 0 ) # [2, W, H]
740741
741742 grid = grid .reshape ([2 , 1 , * grid .shape [1 :]])
742743 pos_embed = get_2d_rotary_pos_embed_from_grid (embed_dim , grid , use_real = use_real )
0 commit comments