@@ -85,6 +85,7 @@ def get_3d_sincos_pos_embed(
     spatial_interpolation_scale: float = 1.0,
     temporal_interpolation_scale: float = 1.0,
     device: Optional[torch.device] = None,
+    output_type: str = "np",
 ) -> torch.Tensor:
     r"""
     Creates 3D sinusoidal positional embeddings.
@@ -107,6 +108,20 @@ def get_3d_sincos_pos_embed(
         The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
         embed_dim]`.
     """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_3d_sincos_pos_embed` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt'` to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_3d_sincos_pos_embed_np(
+            embed_dim=embed_dim,
+            spatial_size=spatial_size,
+            temporal_size=temporal_size,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+        )
     if embed_dim % 4 != 0:
         raise ValueError("`embed_dim` must be divisible by 4")
     if isinstance(spatial_size, int):
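For callers migrating off the NumPy default, a minimal usage sketch of the two paths (assuming the helper is imported from `diffusers.models.embeddings`; sizes are illustrative):

```python
import torch

from diffusers.models.embeddings import get_3d_sincos_pos_embed

# New torch-native path: the embedding is created directly on `device`,
# so the old `torch.from_numpy(...)` wrapping step disappears.
pos = get_3d_sincos_pos_embed(
    embed_dim=64,
    spatial_size=(16, 16),
    temporal_size=8,
    device=torch.device("cpu"),
    output_type="pt",
)
print(pos.shape)  # torch.Size([8, 256, 64])

# Legacy path: `output_type` defaults to "np", emits the deprecation warning
# above, and still returns an `np.ndarray` until the fallback is removed.
pos_np = get_3d_sincos_pos_embed(embed_dim=64, spatial_size=(16, 16), temporal_size=8)
```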
@@ -122,11 +137,11 @@ def get_3d_sincos_pos_embed(
     grid = torch.stack(grid, dim=0)
 
     grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
-    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")
 
     # 2. Temporal
     grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
-    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt")
 
     # 3. Concat
     pos_embed_spatial = pos_embed_spatial[None, :, :]
@@ -141,6 +156,66 @@ def get_3d_sincos_pos_embed(
     return pos_embed
 
 
+def get_3d_sincos_pos_embed_np(
+    embed_dim: int,
+    spatial_size: Union[int, Tuple[int, int]],
+    temporal_size: int,
+    spatial_interpolation_scale: float = 1.0,
+    temporal_interpolation_scale: float = 1.0,
+) -> np.ndarray:
+    r"""
+    Creates 3D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 16.
+        spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to
+            both spatial dimensions (height and width).
+        temporal_size (`int`):
+            The temporal dimension of positional embeddings (number of frames).
+        spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
+        temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+
+    Returns:
+        `np.ndarray`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+            embed_dim]`.
+    """
+    if embed_dim % 4 != 0:
+        raise ValueError("`embed_dim` must be divisible by 4")
+    if isinstance(spatial_size, int):
+        spatial_size = (spatial_size, spatial_size)
+
+    embed_dim_spatial = 3 * embed_dim // 4
+    embed_dim_temporal = embed_dim // 4
+
+    # 1. Spatial
+    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
+    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid_np(embed_dim_spatial, grid)
+
+    # 2. Temporal
+    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid_np(embed_dim_temporal, grid_t)
+
+    # 3. Concat
+    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]
+
+    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]
+
+    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
+    return pos_embed
+
+
 def get_2d_sincos_pos_embed(
     embed_dim,
     grid_size,
@@ -149,6 +224,7 @@ def get_2d_sincos_pos_embed(
     interpolation_scale=1.0,
     base_size=16,
     device: Optional[torch.device] = None,
+    output_type: str = "np",
):
     """
     Creates 2D sinusoidal positional embeddings.
@@ -170,6 +246,21 @@ def get_2d_sincos_pos_embed(
         Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
         embed_dim]` if using cls_token
     """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt'` to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_2d_sincos_pos_embed_np(
+            embed_dim=embed_dim,
+            grid_size=grid_size,
+            cls_token=cls_token,
+            extra_tokens=extra_tokens,
+            interpolation_scale=interpolation_scale,
+            base_size=base_size,
+        )
     if isinstance(grid_size, int):
         grid_size = (grid_size, grid_size)
 
@@ -187,13 +278,13 @@ def get_2d_sincos_pos_embed(
     grid = torch.stack(grid, dim=0)
 
     grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
-    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
     if cls_token and extra_tokens > 0:
         pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
     return pos_embed
 
 
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
     r"""
     This function generates 2D sinusoidal positional embeddings from a grid.
 
@@ -204,18 +295,29 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
     Returns:
         `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
     """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt'` to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_2d_sincos_pos_embed_from_grid_np(
+            embed_dim=embed_dim,
+            grid=grid,
+        )
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
 
     # use half of dimensions to encode grid_h
-    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type)  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type)  # (H*W, D/2)
 
     emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
     return emb
 
 
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
     """
     This function generates 1D positional embeddings from a grid.
 
@@ -226,6 +328,14 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     Returns:
         `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
     """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt'` to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
     if embed_dim % 2 != 0:
         raise ValueError("embed_dim must be divisible by 2")
 
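The same migration applies at the call site of this lower-level helper; a before/after sketch (again assuming imports from `diffusers.models.embeddings`):

```python
import torch

from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid

positions = torch.arange(10, dtype=torch.float32)

# Before (deprecated): NumPy in, NumPy out, wrapped by the caller:
#   emb = torch.from_numpy(get_1d_sincos_pos_embed_from_grid(32, positions.numpy()))

# After: tensor in, tensor out, staying on `positions.device`.
emb = get_1d_sincos_pos_embed_from_grid(32, positions, output_type="pt")
print(emb.shape)  # torch.Size([10, 32])
```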
@@ -243,6 +353,94 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
     return emb
 
 
+def get_2d_sincos_pos_embed_np(
+    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
+):
+    """
+    Creates 2D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension.
+        grid_size (`int`):
+            The size of the grid height and width.
+        cls_token (`bool`, defaults to `False`):
+            Whether or not to add a classification token.
+        extra_tokens (`int`, defaults to `0`):
+            The number of extra tokens to add.
+        interpolation_scale (`float`, defaults to `1.0`):
+            The scale of the interpolation.
+
+    Returns:
+        pos_embed (`np.ndarray`):
+            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
+            embed_dim]` if using cls_token
+    """
+    if isinstance(grid_size, int):
+        grid_size = (grid_size, grid_size)
+
+    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
+    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+    pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
+    if cls_token and extra_tokens > 0:
+        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
+    r"""
+    This function generates 2D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension.
+        grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
+
+    Returns:
+        `np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+    """
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos):
+    """
+    This function generates 1D positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension `D`
+        pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`
+
+    Returns:
+        `numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
+    """
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
 class PatchEmbed(nn.Module):
     """
     2D Image to Patch Embedding with support for SD3 cropping.
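For reference, the `_np` fallbacks above (and their torch counterparts) compute the standard sinusoidal embedding. With embedding dimension `D` and position `m`:

```latex
\omega_d = 10000^{-2d/D}, \qquad
\mathrm{emb}[m,\, d] = \sin(m\,\omega_d), \qquad
\mathrm{emb}[m,\, d + D/2] = \cos(m\,\omega_d), \qquad 0 \le d < D/2
```

The 2D helper spends `D/2` channels each on the height and width grids, and the 3D helper concatenates a `D/4`-channel temporal embedding with a `3D/4`-channel spatial one, matching the `# 3. Concat` blocks above.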
@@ -305,7 +503,11 @@ def __init__(
             self.pos_embed = None
         elif pos_embed_type == "sincos":
             pos_embed = get_2d_sincos_pos_embed(
-                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
+                embed_dim,
+                grid_size,
+                base_size=self.base_size,
+                interpolation_scale=self.interpolation_scale,
+                output_type="pt",
             )
             persistent = True if pos_embed_max_size else False
             self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
@@ -359,6 +561,7 @@ def forward(self, latent):
                     base_size=self.base_size,
                     interpolation_scale=self.interpolation_scale,
                     device=latent.device,
+                    output_type="pt",
                 )
                 pos_embed = pos_embed.float().unsqueeze(0)
             else:
@@ -486,6 +689,7 @@ def _get_positional_embeddings(
             self.spatial_interpolation_scale,
             self.temporal_interpolation_scale,
             device=device,
+            output_type="pt",
         )
         pos_embedding = pos_embedding.flatten(0, 1)
         joint_pos_embedding = torch.zeros(
@@ -575,7 +779,9 @@ def __init__(
         # Linear projection for text embeddings
         self.text_proj = nn.Linear(text_hidden_size, hidden_size)
 
-        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
+        pos_embed = get_2d_sincos_pos_embed(
+            hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
+        )
         pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
         self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
 
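Because both paths coexist until the 0.33.0 removal, a quick parity check can guard the port; a sketch, assuming both variants are exposed by `diffusers.models.embeddings` (the `_np` helper is called directly here to sidestep the deprecation warning):

```python
import torch

from diffusers.models.embeddings import (
    get_2d_sincos_pos_embed,
    get_2d_sincos_pos_embed_np,
)

pt = get_2d_sincos_pos_embed(64, 16, output_type="pt")
ref = torch.from_numpy(get_2d_sincos_pos_embed_np(64, 16))

# Both paths build `omega` in float64, so agreement should be tight;
# any drift would point at a porting bug.
assert pt.shape == ref.shape == (256, 64)
assert torch.allclose(pt.double(), ref.double(), atol=1e-6)
```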