@@ -139,7 +139,13 @@ def get_3d_sincos_pos_embed(
 
 
 def get_2d_sincos_pos_embed(
-    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
+    embed_dim,
+    grid_size,
+    cls_token=False,
+    extra_tokens=0,
+    interpolation_scale=1.0,
+    base_size=16,
+    device: Optional[torch.device] = None,
 ):
     """
     Creates 2D sinusoidal positional embeddings.
@@ -157,22 +163,30 @@ def get_2d_sincos_pos_embed(
             The scale of the interpolation.
 
     Returns:
-        pos_embed (`np.ndarray`):
+        pos_embed (`torch.Tensor`):
             Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
             embed_dim]` if using cls_token
     """
     if isinstance(grid_size, int):
         grid_size = (grid_size, grid_size)
 
-    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
-    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
-    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)
+    grid_h = (
+        torch.arange(grid_size[0], device=device, dtype=torch.float32)
+        / (grid_size[0] / base_size)
+        / interpolation_scale
+    )
+    grid_w = (
+        torch.arange(grid_size[1], device=device, dtype=torch.float32)
+        / (grid_size[1] / base_size)
+        / interpolation_scale
+    )
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)
 
     grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
     if cls_token and extra_tokens > 0:
-        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+        pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
     return pos_embed
 
 
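Illustrative only (not part of the diff): with the new `device` keyword the embedding is built directly on the target accelerator, so callers no longer need a `torch.from_numpy(...).to(device)` round-trip. A minimal sketch, assuming a CUDA device is available; the argument values are made up for the example:

    # Hypothetical call: build a 32x32 grid of sincos embeddings directly on GPU.
    pos_embed = get_2d_sincos_pos_embed(
        embed_dim=768,
        grid_size=32,                    # expands to a (32, 32) grid
        device=torch.device("cuda"),     # new keyword introduced above
    )
    print(pos_embed.shape)   # torch.Size([1024, 768]) -> grid_size * grid_size positions
    print(pos_embed.device)  # cuda:0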
@@ -182,10 +196,10 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
 
     Args:
         embed_dim (`int`): The embedding dimension.
-        grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
+        grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
 
     Returns:
-        `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+        `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
     """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
@@ -194,7 +208,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
 
-    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
     return emb
 
 
@@ -204,25 +218,25 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
     Args:
         embed_dim (`int`): The embedding dimension `D`
-        pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`
+        pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
 
     Returns:
-        `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
+        `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
     """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
 
-    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
     omega /= embed_dim / 2.0
     omega = 1.0 / 10000**omega  # (D/2,)
 
     pos = pos.reshape(-1)  # (M,)
-    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+    out = torch.outer(pos, omega)  # (M, D/2), outer product
 
-    emb_sin = np.sin(out)  # (M, D/2)
-    emb_cos = np.cos(out)  # (M, D/2)
+    emb_sin = torch.sin(out)  # (M, D/2)
+    emb_cos = torch.cos(out)  # (M, D/2)
 
-    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
     return emb
 
 
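A quick parity sketch (my own check, not from the diff): torch.outer computes the same (M, D/2) outer product as the np.einsum("m,d->md", ...) call it replaces, so the generated embeddings should match the old NumPy path up to floating-point noise:

    import numpy as np
    import torch

    pos = np.arange(16, dtype=np.float64)
    omega = 1.0 / 10000 ** (np.arange(8, dtype=np.float64) / 8.0)

    old = np.einsum("m,d->md", pos, omega)                                      # previous formulation
    new = torch.outer(torch.from_numpy(pos), torch.from_numpy(omega)).numpy()   # new formulation

    assert np.allclose(old, new)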
@@ -291,7 +305,7 @@ def __init__(
                 embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
             )
             persistent = True if pos_embed_max_size else False
-            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
+            self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
         else:
             raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
 
@@ -341,8 +355,9 @@ def forward(self, latent):
                     grid_size=(height, width),
                     base_size=self.base_size,
                     interpolation_scale=self.interpolation_scale,
+                    device=latent.device,
                 )
-                pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+                pos_embed = pos_embed.float().unsqueeze(0)
             else:
                 pos_embed = self.pos_embed
 
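Because the embedding is now allocated on latent.device from the start, the separate host-to-device copy that .to(latent.device) used to perform is no longer needed. A minimal before/after sketch with a hypothetical latent tensor (sizes are illustrative, not from this file):

    latent = torch.randn(2, 4096, 1152, device="cuda")  # hypothetical activation

    # before: NumPy on CPU, then convert and copy to the GPU
    # pos_embed = torch.from_numpy(get_2d_sincos_pos_embed(...)).float().unsqueeze(0).to(latent.device)

    # after: built on latent.device directly; only dtype and batch-dim fixes remain
    pos_embed = get_2d_sincos_pos_embed(
        embed_dim=1152, grid_size=(64, 64), device=latent.device
    )
    pos_embed = pos_embed.float().unsqueeze(0)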
@@ -554,7 +569,7 @@ def __init__(
 
         pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
         pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
-        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
+        self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
 
     def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, channel, height, width = hidden_states.shape