@@ -139,7 +139,13 @@ def get_3d_sincos_pos_embed(
 
 
 def get_2d_sincos_pos_embed(
-    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
+    embed_dim,
+    grid_size,
+    cls_token=False,
+    extra_tokens=0,
+    interpolation_scale=1.0,
+    base_size=16,
+    device: Optional[torch.device] = None,
 ):
     """
     Creates 2D sinusoidal positional embeddings.
@@ -157,22 +163,30 @@ def get_2d_sincos_pos_embed(
             The scale of the interpolation.
 
     Returns:
-        pos_embed (`np.ndarray`):
+        pos_embed (`torch.Tensor`):
             Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
             embed_dim]` if using cls_token
     """
     if isinstance(grid_size, int):
         grid_size = (grid_size, grid_size)
 
-    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
-    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
-    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)
+    grid_h = (
+        torch.arange(grid_size[0], device=device, dtype=torch.float32)
+        / (grid_size[0] / base_size)
+        / interpolation_scale
+    )
+    grid_w = (
+        torch.arange(grid_size[1], device=device, dtype=torch.float32)
+        / (grid_size[1] / base_size)
+        / interpolation_scale
+    )
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)
 
     grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
     if cls_token and extra_tokens > 0:
-        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+        pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
     return pos_embed
 
 
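A quick sanity check of the new signature (a hedged sketch, not part of the diff: the import path `diffusers.models.embeddings` and the concrete sizes are assumptions; with this diff applied the function returns a `torch.Tensor` built directly on the requested device):

```python
# Hypothetical usage of the updated function; assumes this diff is applied and that it
# is importable from diffusers.models.embeddings (the file path is not shown above).
import torch
from diffusers.models.embeddings import get_2d_sincos_pos_embed

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
pos_embed = get_2d_sincos_pos_embed(
    embed_dim=64,
    grid_size=(16, 16),
    device=device,  # new keyword introduced by this change
)
print(type(pos_embed), pos_embed.shape)  # <class 'torch.Tensor'> torch.Size([256, 64])
```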
@@ -182,10 +196,10 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
 
     Args:
         embed_dim (`int`): The embedding dimension.
-        grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
+        grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
 
     Returns:
-        `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+        `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
     """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
@@ -194,7 +208,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
 
-    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
     return emb
 
 
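The hunk above only swaps `np.concatenate` for `torch.concat`; the structure is unchanged: half of the channels encode `grid[0]`, the other half `grid[1]`. A self-contained sketch of that layout (the inlined helper re-states the 1D construction from this file purely for illustration):

```python
# Illustration only: sincos_1d mirrors get_1d_sincos_pos_embed_from_grid below so this
# snippet runs standalone; the real code calls the library function instead.
import torch

def sincos_1d(embed_dim, pos):
    omega = torch.arange(embed_dim // 2, dtype=torch.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**omega
    out = torch.outer(pos.reshape(-1).double(), omega)
    return torch.concat([torch.sin(out), torch.cos(out)], dim=1)

pos = torch.arange(6, dtype=torch.float32)  # 6 flattened grid positions
emb = torch.concat([sincos_1d(4, pos), sincos_1d(4, pos)], dim=1)
print(emb.shape)  # torch.Size([6, 8]): (H*W, D), with D/2 channels per grid axis
```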
@@ -204,25 +218,25 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
     Args:
         embed_dim (`int`): The embedding dimension `D`
-        pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`
+        pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
 
     Returns:
-        `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
+        `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
     """
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
 
-    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
     omega /= embed_dim / 2.0
     omega = 1.0 / 10000**omega  # (D/2,)
 
     pos = pos.reshape(-1)  # (M,)
-    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+    out = torch.outer(pos, omega)  # (M, D/2), outer product
 
-    emb_sin = np.sin(out)  # (M, D/2)
-    emb_cos = np.cos(out)  # (M, D/2)
+    emb_sin = torch.sin(out)  # (M, D/2)
+    emb_cos = torch.cos(out)  # (M, D/2)
 
-    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
     return emb
 
 
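The only behavioral question in this hunk is whether `torch.outer` reproduces the old `np.einsum("m,d->md", pos, omega)` outer product; a minimal check with toy sizes (purely illustrative, not part of the PR):

```python
# Toy-sized comparison of the old (numpy) and new (torch) outer-product formulations.
import numpy as np
import torch

pos = torch.arange(5, dtype=torch.float64)
omega = 1.0 / 10000 ** (torch.arange(8, dtype=torch.float64) / 8)

new_out = torch.outer(pos, omega)                           # (M, D/2), torch path
old_out = np.einsum("m,d->md", pos.numpy(), omega.numpy())  # (M, D/2), numpy path
assert np.allclose(new_out.numpy(), old_out)
```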
@@ -291,7 +305,7 @@ def __init__(
                 embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
             )
             persistent = True if pos_embed_max_size else False
-            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
+            self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
         else:
             raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
 
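For context on the `persistent` flag used above (standard PyTorch behavior, not something this PR changes): a non-persistent buffer still moves with the module across devices and dtypes but is left out of `state_dict()`, and the code above only persists the buffer when `pos_embed_max_size` is set. A generic demo:

```python
# Generic PyTorch demo of persistent vs. non-persistent buffers; class name is made up.
import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("kept", torch.zeros(1), persistent=True)
        self.register_buffer("dropped", torch.zeros(1), persistent=False)

print(list(Demo().state_dict().keys()))  # ['kept']
```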
@@ -341,8 +355,9 @@ def forward(self, latent):
                 grid_size=(height, width),
                 base_size=self.base_size,
                 interpolation_scale=self.interpolation_scale,
+                device=latent.device,
             )
-            pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+            pos_embed = pos_embed.float().unsqueeze(0)
         else:
             pos_embed = self.pos_embed
 
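The `forward()` change is the payoff of threading `device` through: the embedding is built on `latent.device` instead of being created in NumPy on the host and then moved with `.to(latent.device)`. A hedged sketch of the resulting behavior (sizes and the import path are assumptions):

```python
# Assumes the diff is applied; embed_dim and grid sizes are arbitrary demo values.
import torch
from diffusers.models.embeddings import get_2d_sincos_pos_embed

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
latent = torch.randn(1, 4, 32, 32, device=device)

pos_embed = get_2d_sincos_pos_embed(
    embed_dim=96,
    grid_size=(32, 32),
    base_size=16,
    device=latent.device,
)
pos_embed = pos_embed.float().unsqueeze(0)  # mirrors the line added in forward()
assert pos_embed.device == latent.device    # no explicit .to(latent.device) hop needed
```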
@@ -554,7 +569,7 @@ def __init__(
 
         pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
         pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
-        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
+        self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
 
     def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, channel, height, width = hidden_states.shape