@@ -1941,6 +1941,8 @@ def __init__(
19411941 def get_vision_embeddings (self , pixel_values , input_ids = None , temporal_ids = None , ** kwargs ):
19421942 if input_ids is not None and input_ids .shape [1 ] == 1 :
19431943 return None
1944+
1945+ all_temporal_ids = None
19441946 if temporal_ids is not None :
19451947 all_temporal_ids = []
19461948 for t in temporal_ids :
@@ -2020,7 +2022,7 @@ def resampling(self, x, tgt_sizes, temporal_ids=None):
20202022
20212023 max_patch_len = torch .max (patch_len )
20222024 key_padding_mask = torch .zeros ((bs , max_patch_len ), dtype = torch .bool )
2023-
2025+
20242026 temporal_embed = None
20252027 pos_embed = []
20262028 pos_embed_temporal = []
@@ -2039,8 +2041,8 @@ def resampling(self, x, tgt_sizes, temporal_ids=None):
20392041 pos_embed = torch .nn .utils .rnn .pad_sequence (pos_embed , batch_first = True , padding_value = 0.0 ).permute (
20402042 1 , 0 , 2
20412043 ) # BLD => L * B * D
2042-
2043- temporal_embed = torch .stack (pos_embed_temporal , dim = 0 ).unsqueeze (0 )
2044+ if temporal_pos_emb :
2045+ temporal_embed = torch .stack (pos_embed_temporal , dim = 0 ).unsqueeze (0 )
20442046 res = torch .from_numpy (
20452047 self .resampler (
20462048 image_feature = x ,
@@ -4483,4 +4485,4 @@ def preprocess_inputs(
44834485 "phi4_multimodal" : _OVPhi4MMForCausalLM ,
44844486 "llama4" : _OVLlama4ForCausalLM ,
44854487 "minicpmo" : _OVMiniCPMOForCausalLM ,
4486- }
4488+ }
0 commit comments