@@ -84,7 +84,8 @@ def get_3d_sincos_pos_embed(
     temporal_size: int,
     spatial_interpolation_scale: float = 1.0,
     temporal_interpolation_scale: float = 1.0,
-) -> np.ndarray:
+    device: Optional[torch.device] = None,
+) -> torch.Tensor:
     r"""
     Creates 3D sinusoidal positional embeddings.

@@ -102,7 +103,7 @@ def get_3d_sincos_pos_embed(
             Scale factor for temporal grid interpolation.

     Returns:
-        `np.ndarray`:
+        `torch.Tensor`:
             The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
             embed_dim]`.
     """
@@ -115,26 +116,28 @@ def get_3d_sincos_pos_embed(
     embed_dim_temporal = embed_dim // 4

     # 1. Spatial
-    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
-    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
-    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
-    grid = np.stack(grid, axis=0)
+    grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)

     grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
     pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)

     # 2. Temporal
-    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+    grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
     pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)

     # 3. Concat
-    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
-    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]
+    pos_embed_spatial = pos_embed_spatial[None, :, :]
+    pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0)  # [T, H*W, D // 4 * 3]

-    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
-    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]
+    pos_embed_temporal = pos_embed_temporal[:, None, :]
+    pos_embed_temporal = pos_embed_temporal.repeat_interleave(
+        spatial_size[0] * spatial_size[1], dim=1
+    )  # [T, H*W, D // 4]

-    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
+    pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1)  # [T, H*W, D]
     return pos_embed


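A minimal usage sketch of the updated helper (assuming it is imported from `diffusers.models.embeddings`, and using illustrative sizes): with the `device` argument added in this diff, the embedding is created directly on the target device as a `torch.Tensor`, so no NumPy round trip is needed.

```python
import torch
from diffusers.models.embeddings import get_3d_sincos_pos_embed

# Hypothetical sizes for illustration: 8 latent frames, a 45x30 (W x H) patch grid, 64-dim embeddings.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pos_embed = get_3d_sincos_pos_embed(
    embed_dim=64,
    spatial_size=(45, 30),           # (width, height) in patches; spatial_size[0] is width above
    temporal_size=8,
    spatial_interpolation_scale=1.0,
    temporal_interpolation_scale=1.0,
    device=device,                   # new in this change: the tensor is built on `device`
)
print(pos_embed.shape, pos_embed.device)  # torch.Size([8, 1350, 64]) on `device`
```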
@@ -468,7 +471,9 @@ def __init__(
         pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
         self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)

-    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+    def _get_positional_embeddings(
+        self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None
+    ) -> torch.Tensor:
         post_patch_height = sample_height // self.patch_size
         post_patch_width = sample_width // self.patch_size
         post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
@@ -480,8 +485,9 @@ def _get_positional_embeddings(self, sample_height: int, sample_width: int, samp
             post_time_compression_frames,
             self.spatial_interpolation_scale,
             self.temporal_interpolation_scale,
+            device=device,
         )
-        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+        pos_embedding = pos_embedding.flatten(0, 1)
         joint_pos_embedding = torch.zeros(
             1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
         )
@@ -536,8 +542,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
             or self.sample_width != width
             or self.sample_frames != pre_time_compression_frames
         ):
-            pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
-            pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+            pos_embedding = self._get_positional_embeddings(
+                height, width, pre_time_compression_frames, device=embeds.device
+            )
+            pos_embedding = pos_embedding.to(dtype=embeds.dtype)
         else:
             pos_embedding = self.pos_embedding

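For context, the change in `forward` above follows a recompute-or-reuse pattern: the registered `pos_embedding` buffer is reused when the runtime sizes match, and otherwise a fresh embedding is built directly on `embeds.device`, leaving only a dtype cast. A simplified sketch of that pattern with illustrative names, not the library's exact code:

```python
import torch
from typing import Callable

def resolve_pos_embedding(
    embeds: torch.Tensor,
    cached: torch.Tensor,
    sizes_match: bool,
    rebuild: Callable[..., torch.Tensor],  # e.g. a closure over the module's embedding builder
) -> torch.Tensor:
    if sizes_match:
        # Registered buffer: already on the module's device.
        return cached
    # Build on the activations' device (no CPU tensor followed by a .to(device) copy),
    # then only cast the dtype to match the activations.
    fresh = rebuild(device=embeds.device)
    return fresh.to(dtype=embeds.dtype)
```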