@@ -287,19 +287,10 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
 
     def forward(self, image_feature, pos_embed, key_padding_mask, temporal_embed=None):
         self.compile()
+        inputs = {"image_feature": image_feature, "pos_embed": pos_embed, "key_padding_mask": key_padding_mask}
         if temporal_embed is not None:
-            result = self.request(
-                {
-                    "image_feature": image_feature,
-                    "pos_embed": pos_embed,
-                    "key_padding_mask": key_padding_mask,
-                    "temporal_embed": temporal_embed,
-                }
-            )[0]
-        else:
-            result = self.request(
-                {"image_feature": image_feature, "pos_embed": pos_embed, "key_padding_mask": key_padding_mask}
-            )[0]
+            inputs["temporal_embed"] = temporal_embed
+        result = self.request(inputs)[0]
         return result
 
 
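The hunk above deduplicates the two `self.request(...)` calls by assembling a single inputs dict and attaching the optional `temporal_embed` key only when it is provided. A minimal standalone sketch of the same pattern, where `infer_request` is a hypothetical stand-in for the compiled model's request object:

```python
# Sketch of the conditional-inputs pattern from the diff above.
# `infer_request` is a stand-in callable, not the real OpenVINO API surface.
def forward(infer_request, image_feature, pos_embed, key_padding_mask, temporal_embed=None):
    inputs = {
        "image_feature": image_feature,
        "pos_embed": pos_embed,
        "key_padding_mask": key_padding_mask,
    }
    if temporal_embed is not None:
        # Optional input: only present for models exported with temporal embeddings.
        inputs["temporal_embed"] = temporal_embed
    return infer_request(inputs)[0]
```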
@@ -2000,6 +1991,40 @@ def get_vision_embeddings(self, pixel_values, input_ids=None, temporal_ids=None,
             vision_hidden_states.append(dummy_feature)
         return vision_hidden_states
 
+    def get_1d_sincos_pos_embed_from_temporal_size(self, embed_dim, pos):
+        """
+        embed_dim: output dimension for each position
+        pos: a list of positions to be encoded: size (M,)
+        out: (M, D)
+        """
+        assert embed_dim % 2 == 0
+        omega = np.arange(embed_dim // 2, dtype=np.float32)
+        omega /= embed_dim / 2.0
+        omega = 1.0 / 10000**omega  # (D/2,)
+
+        pos = pos.reshape(-1)  # (M,)
+        out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+        emb_sin = np.sin(out)  # (M, D/2)
+        emb_cos = np.cos(out)  # (M, D/2)
+
+        emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+        return emb
+
+    def _set_temporal_pos_cache(self, max_temporal_size, device="cpu"):
+        temporal_size = np.arange(max_temporal_size, dtype=np.float32)
+        pos_embed = (
+            torch.from_numpy(self.get_1d_sincos_pos_embed_from_temporal_size(self.embed_dim, temporal_size))
+            .float()
+            .to(device)
+        )
+        self.temporal_pos_embed = pos_embed
+
+    def _adjust_temporal_pos_cache(self, max_temporal_size, device):
+        if max_temporal_size > self.max_temporal_size:
+            self.max_temporal_size = max_temporal_size
+            self._set_temporal_pos_cache(self.max_temporal_size, device)
+
     def resampling(self, x, tgt_sizes, temporal_ids=None):
         from itertools import chain
 
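As a sanity check on the new `get_1d_sincos_pos_embed_from_temporal_size` helper, the same sin/cos table can be exercised standalone; the sizes in the demo below are illustrative only:

```python
import numpy as np

def sincos_1d(embed_dim, pos):
    # Same math as get_1d_sincos_pos_embed_from_temporal_size above:
    # the first D/2 channels get sin(pos * omega), the rest cos(pos * omega).
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega                            # frequencies, (D/2,)
    out = np.einsum("m,d->md", pos.reshape(-1), omega)    # outer product, (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)

emb = sincos_1d(4, np.arange(3, dtype=np.float32))
print(emb.shape)  # (3, 4)
print(emb[0])     # position 0 -> [0. 0. 1. 1.]: sin terms are 0, cos terms are 1
```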
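Together, the two cache helpers implement a grow-only cache: `_set_temporal_pos_cache` materializes the positional table once, and `_adjust_temporal_pos_cache` recomputes it only when a longer temporal span than any seen so far is requested. A minimal sketch of that behavior; the class name and default size are illustrative, not from the diff:

```python
import numpy as np
import torch

def _sincos_1d(embed_dim, pos):
    # Condensed 1D sin/cos table, same layout as in the diff above.
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float32) / (embed_dim / 2.0))
    out = np.outer(pos, omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)

class TemporalPosCache:
    # Illustrative stand-in for the resampler attributes the diff relies on.
    def __init__(self, embed_dim, max_temporal_size=64, device="cpu"):
        self.embed_dim = embed_dim
        self.max_temporal_size = max_temporal_size
        self._set_temporal_pos_cache(max_temporal_size, device)

    def _set_temporal_pos_cache(self, max_temporal_size, device="cpu"):
        pos = np.arange(max_temporal_size, dtype=np.float32)
        self.temporal_pos_embed = torch.from_numpy(_sincos_1d(self.embed_dim, pos)).float().to(device)

    def _adjust_temporal_pos_cache(self, max_temporal_size, device="cpu"):
        # Grow-only: shorter requests reuse the existing table.
        if max_temporal_size > self.max_temporal_size:
            self.max_temporal_size = max_temporal_size
            self._set_temporal_pos_cache(max_temporal_size, device)

cache = TemporalPosCache(embed_dim=8)
cache._adjust_temporal_pos_cache(32)   # no-op: 32 <= 64
cache._adjust_temporal_pos_cache(128)  # table grows to 128 positions
print(cache.temporal_pos_embed.shape)  # torch.Size([128, 8])
```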