@@ -84,6 +84,78 @@ def get_3d_sincos_pos_embed(
     temporal_size: int,
     spatial_interpolation_scale: float = 1.0,
     temporal_interpolation_scale: float = 1.0,
+    device: Optional[torch.device] = None,
+    output_type: str = "np",
+) -> torch.Tensor:
+    r"""
+    Creates 3D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 16.
+        spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+            spatial dimensions (height and width).
+        temporal_size (`int`):
+            The temporal dimension of positional embeddings (number of frames).
+        spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
+        temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+
+    Returns:
+        `torch.Tensor`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+            embed_dim]`.
+    """
+    if output_type == "np":
+        return _get_3d_sincos_pos_embed_np(
+            embed_dim=embed_dim,
+            spatial_size=spatial_size,
+            temporal_size=temporal_size,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+        )
+    if embed_dim % 4 != 0:
+        raise ValueError("`embed_dim` must be divisible by 4")
+    if isinstance(spatial_size, int):
+        spatial_size = (spatial_size, spatial_size)
+
+    embed_dim_spatial = 3 * embed_dim // 4
+    embed_dim_temporal = embed_dim // 4
+
+    # 1. Spatial
+    grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)
+
+    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")
+
+    # 2. Temporal
+    grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt")
+
+    # 3. Concat
+    pos_embed_spatial = pos_embed_spatial[None, :, :]
+    pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0)  # [T, H*W, D // 4 * 3]
+
+    pos_embed_temporal = pos_embed_temporal[:, None, :]
+    pos_embed_temporal = pos_embed_temporal.repeat_interleave(
+        spatial_size[0] * spatial_size[1], dim=1
+    )  # [T, H*W, D // 4]
+
+    pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1)  # [T, H*W, D]
+    return pos_embed
+
+
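A minimal usage sketch of the new torch-native path added above (the import path and the argument values are assumptions for illustration; only the signature comes from this diff):

```python
import torch

from diffusers.models.embeddings import get_3d_sincos_pos_embed  # assumed import path

pos_embed = get_3d_sincos_pos_embed(
    embed_dim=64,                # divisible by 16, as the docstring requires
    spatial_size=(16, 16),       # square grid -> 16 * 16 = 256 spatial tokens
    temporal_size=8,             # number of (compressed) frames
    device=torch.device("cpu"),  # the embedding is built directly on this device
    output_type="pt",            # select the torch path; "np" keeps the deprecated behaviour
)
print(pos_embed.shape)  # torch.Size([8, 256, 64]) -> [temporal_size, H * W, embed_dim]
```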
+def _get_3d_sincos_pos_embed_np(
+    embed_dim: int,
+    spatial_size: Union[int, Tuple[int, int]],
+    temporal_size: int,
+    spatial_interpolation_scale: float = 1.0,
+    temporal_interpolation_scale: float = 1.0,
 ) -> np.ndarray:
     r"""
     Creates 3D sinusoidal positional embeddings.
@@ -106,6 +178,12 @@ def get_3d_sincos_pos_embed(
             The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
             embed_dim]`.
     """
+    deprecation_message = (
+        "`get_3d_sincos_pos_embed` uses `torch` and supports `device`."
+        " `from_numpy` is no longer required."
+        " Pass `output_type='pt'` to use the new version now."
+    )
+    deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
     if embed_dim % 4 != 0:
         raise ValueError("`embed_dim` must be divisible by 4")
     if isinstance(spatial_size, int):
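For reference, a hedged sketch of what callers of the old default now see, assuming a diffusers version where the NumPy path still exists (it is scheduled for removal in 0.33.0) and that diffusers' `deprecate` helper emits a warning with the message above:

```python
import warnings

from diffusers.models.embeddings import get_3d_sincos_pos_embed  # assumed import path

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    emb = get_3d_sincos_pos_embed(64, (8, 8), 4)  # output_type defaults to "np"

print(type(emb).__name__)  # ndarray -- the NumPy behaviour is preserved for now
print(caught[0].message)   # deprecation message pointing to output_type="pt"
```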
@@ -139,6 +217,143 @@ def get_3d_sincos_pos_embed(
 
 
 def get_2d_sincos_pos_embed(
+    embed_dim,
+    grid_size,
+    cls_token=False,
+    extra_tokens=0,
+    interpolation_scale=1.0,
+    base_size=16,
+    device: Optional[torch.device] = None,
+    output_type: str = "np",
+):
+    """
+    Creates 2D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension.
+        grid_size (`int`):
+            The size of the grid height and width.
+        cls_token (`bool`, defaults to `False`):
+            Whether or not to add a classification token.
+        extra_tokens (`int`, defaults to `0`):
+            The number of extra tokens to add.
+        interpolation_scale (`float`, defaults to `1.0`):
+            The scale of the interpolation.
+
+    Returns:
+        pos_embed (`torch.Tensor`):
+            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
+            embed_dim]` if using cls_token
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt'` to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_2d_sincos_pos_embed_np(
+            embed_dim=embed_dim,
+            grid_size=grid_size,
+            cls_token=cls_token,
+            extra_tokens=extra_tokens,
+            interpolation_scale=interpolation_scale,
+            base_size=base_size,
+        )
+    if isinstance(grid_size, int):
+        grid_size = (grid_size, grid_size)
+
+    grid_h = (
+        torch.arange(grid_size[0], device=device, dtype=torch.float32)
+        / (grid_size[0] / base_size)
+        / interpolation_scale
+    )
+    grid_w = (
+        torch.arange(grid_size[1], device=device, dtype=torch.float32)
+        / (grid_size[1] / base_size)
+        / interpolation_scale
+    )
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)
+
+    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
+    if cls_token and extra_tokens > 0:
+        pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
+    return pos_embed
+
+
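A quick parity sketch between the deprecated NumPy path and the new torch path of `get_2d_sincos_pos_embed` (import path assumed, sizes illustrative; a small tolerance is used because the two backends may differ in the last few bits):

```python
import numpy as np
import torch

from diffusers.models.embeddings import get_2d_sincos_pos_embed  # assumed import path

emb_np = get_2d_sincos_pos_embed(64, 16)                    # deprecated NumPy path (emits the warning)
emb_pt = get_2d_sincos_pos_embed(64, 16, output_type="pt")  # new torch path
assert emb_pt.shape == (16 * 16, 64)
assert np.allclose(emb_np, emb_pt.cpu().numpy(), atol=1e-6)
```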
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
+    r"""
+    This function generates 2D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension.
+        grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
+
+    Returns:
+        `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt'` to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_2d_sincos_pos_embed_from_grid_np(
+            embed_dim=embed_dim,
+            grid=grid,
+        )
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type)  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type)  # (H*W, D/2)
+
+    emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
+    """
+    This function generates 1D positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension `D`
+        pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
+
+    Returns:
+        `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt'` to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = torch.outer(pos, omega)  # (M, D/2), outer product
+
+    emb_sin = torch.sin(out)  # (M, D/2)
+    emb_cos = torch.cos(out)  # (M, D/2)
+
+    emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
+    return emb
+
+
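To make the 1D construction above concrete: frequency `i` is `1 / 10000**(2i / D)`, the first `D/2` columns hold sines and the last `D/2` hold cosines, so column 0 is `sin(pos)` and column `D/2` is `cos(pos)`. A tiny check, with illustrative values and an assumed import path:

```python
import torch

from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid  # assumed import path

pos = torch.arange(5, dtype=torch.float32)
emb = get_1d_sincos_pos_embed_from_grid(8, pos, output_type="pt")  # shape (5, 8)
assert torch.allclose(emb[:, 0], torch.sin(pos).to(emb.dtype))  # lowest frequency, sine half
assert torch.allclose(emb[:, 4], torch.cos(pos).to(emb.dtype))  # same frequency, cosine half
```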
+def get_2d_sincos_pos_embed_np(
     embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
 ):
     """
@@ -170,13 +385,13 @@ def get_2d_sincos_pos_embed(
     grid = np.stack(grid, axis=0)
 
     grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
-    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
     if cls_token and extra_tokens > 0:
         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
     return pos_embed
 
 
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
     r"""
     This function generates 2D sinusoidal positional embeddings from a grid.
 
@@ -191,14 +406,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
         raise ValueError("embed_dim must be divisible by 2")
 
     # use half of dimensions to encode grid_h
-    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+    emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1])  # (H*W, D/2)
 
     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
     return emb
 
 
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos):
     """
     This function generates 1D positional embeddings from a grid.
 
@@ -288,10 +503,14 @@ def __init__(
             self.pos_embed = None
         elif pos_embed_type == "sincos":
             pos_embed = get_2d_sincos_pos_embed(
-                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
+                embed_dim,
+                grid_size,
+                base_size=self.base_size,
+                interpolation_scale=self.interpolation_scale,
+                output_type="pt",
             )
             persistent = True if pos_embed_max_size else False
-            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
+            self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
         else:
             raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
 
@@ -341,8 +560,10 @@ def forward(self, latent):
                     grid_size=(height, width),
                     base_size=self.base_size,
                     interpolation_scale=self.interpolation_scale,
+                    device=latent.device,
+                    output_type="pt",
                 )
-                pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+                pos_embed = pos_embed.float().unsqueeze(0)
             else:
                 pos_embed = self.pos_embed
 
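With this change the dynamically recomputed embedding is created on `latent.device` directly instead of being built in NumPy and moved afterwards. A hedged sketch of the behaviour (constructor arguments are illustrative and not taken from this diff):

```python
import torch

from diffusers.models.embeddings import PatchEmbed  # assumed import path

patch_embed = PatchEmbed(height=64, width=64, patch_size=2, in_channels=4, embed_dim=64)
latent = torch.randn(1, 4, 96, 96)  # spatial size differs from (64, 64), so pos_embed is recomputed
out = patch_embed(latent)
print(out.shape)  # torch.Size([1, 2304, 64]) -> (96 / 2) ** 2 = 2304 patches
```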
@@ -453,7 +674,9 @@ def __init__(
             pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
             self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
 
-    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+    def _get_positional_embeddings(
+        self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None
+    ) -> torch.Tensor:
         post_patch_height = sample_height // self.patch_size
         post_patch_width = sample_width // self.patch_size
         post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
@@ -465,8 +688,10 @@ def _get_positional_embeddings(self, sample_height: int, sample_width: int, samp
             post_time_compression_frames,
             self.spatial_interpolation_scale,
             self.temporal_interpolation_scale,
+            device=device,
+            output_type="pt",
         )
-        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+        pos_embedding = pos_embedding.flatten(0, 1)
         joint_pos_embedding = torch.zeros(
             1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
         )
@@ -521,8 +746,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
                 or self.sample_width != width
                 or self.sample_frames != pre_time_compression_frames
             ):
-                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
-                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+                pos_embedding = self._get_positional_embeddings(
+                    height, width, pre_time_compression_frames, device=embeds.device
+                )
+                pos_embedding = pos_embedding.to(dtype=embeds.dtype)
             else:
                 pos_embedding = self.pos_embedding
 
@@ -552,9 +779,11 @@ def __init__(
         # Linear projection for text embeddings
         self.text_proj = nn.Linear(text_hidden_size, hidden_size)
 
-        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
+        pos_embed = get_2d_sincos_pos_embed(
+            hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
+        )
         pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
-        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
+        self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
 
     def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, channel, height, width = hidden_states.shape