@@ -41,12 +41,12 @@ def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0)
4141
4242
4343class VideoPositionEmb (nn .Module ):
44- def forward (self , x_B_T_H_W_C : torch .Tensor , fps = Optional [torch .Tensor ], device = None ) -> torch .Tensor :
44+ def forward (self , x_B_T_H_W_C : torch .Tensor , fps = Optional [torch .Tensor ], device = None , dtype = None ) -> torch .Tensor :
4545 """
4646 It delegates the embedding generation to generate_embeddings function.
4747 """
4848 B_T_H_W_C = x_B_T_H_W_C .shape
49- embeddings = self .generate_embeddings (B_T_H_W_C , fps = fps , device = device )
49+ embeddings = self .generate_embeddings (B_T_H_W_C , fps = fps , device = device , dtype = dtype )
5050
5151 return embeddings
5252
@@ -104,6 +104,7 @@ def generate_embeddings(
104104 w_ntk_factor : Optional [float ] = None ,
105105 t_ntk_factor : Optional [float ] = None ,
106106 device = None ,
107+ dtype = None ,
107108 ):
108109 """
109110 Generate embeddings for the given input size.
@@ -189,13 +190,12 @@ def __init__(
189190 self .pos_emb_w = nn .Parameter (torch .empty (len_w , model_channels , device = device , dtype = dtype ))
190191 self .pos_emb_t = nn .Parameter (torch .empty (len_t , model_channels , device = device , dtype = dtype ))
191192
192-
193- def generate_embeddings (self , B_T_H_W_C : torch .Size , fps = Optional [torch .Tensor ], device = None ) -> torch .Tensor :
193+ def generate_embeddings (self , B_T_H_W_C : torch .Size , fps = Optional [torch .Tensor ], device = None , dtype = None ) -> torch .Tensor :
194194 B , T , H , W , _ = B_T_H_W_C
195195 if self .interpolation == "crop" :
196- emb_h_H = self .pos_emb_h [:H ].to (device = device )
197- emb_w_W = self .pos_emb_w [:W ].to (device = device )
198- emb_t_T = self .pos_emb_t [:T ].to (device = device )
196+ emb_h_H = self .pos_emb_h [:H ].to (device = device , dtype = dtype )
197+ emb_w_W = self .pos_emb_w [:W ].to (device = device , dtype = dtype )
198+ emb_t_T = self .pos_emb_t [:T ].to (device = device , dtype = dtype )
199199 emb = (
200200 repeat (emb_t_T , "t d-> b t h w d" , b = B , h = H , w = W )
201201 + repeat (emb_h_H , "h d-> b t h w d" , b = B , t = T , w = W )
0 commit comments