Skip to content

Commit a96d993

Browse files
update
1 parent f6d7750 commit a96d993

File tree

1 file changed

+13
-17
lines changed

1 file changed

+13
-17
lines changed

optimum/intel/openvino/modeling_visual_language.py

Lines changed: 13 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -286,7 +286,7 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
286286
self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
287287

288288
def forward(self, image_feature, pos_embed, key_padding_mask, temporal_embed=None):
289-
self._compile()
289+
self.compile()
290290
if temporal_embed is not None:
291291
result = self.request(
292292
{
@@ -2020,7 +2020,8 @@ def resampling(self, x, tgt_sizes, temporal_ids=None):
20202020

20212021
max_patch_len = torch.max(patch_len)
20222022
key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool)
2023-
2023+
2024+
temporal_embed = None
20242025
pos_embed = []
20252026
pos_embed_temporal = []
20262027
for i in range(bs):
@@ -2038,21 +2039,16 @@ def resampling(self, x, tgt_sizes, temporal_ids=None):
20382039
pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, padding_value=0.0).permute(
20392040
1, 0, 2
20402041
) # BLD => L * B * D
2041-
if pos_embed_temporal:
2042-
temporal_embed = torch.stack(pos_embed_temporal, dim=0).unsqueeze(0)
2043-
res = torch.from_numpy(
2044-
self.resampler(
2045-
image_feature=x,
2046-
pos_embed=pos_embed,
2047-
key_padding_mask=key_padding_mask,
2048-
temporal_embed=temporal_embed,
2049-
)
2050-
)
2051-
else:
2052-
# Print shapes of all inputs to resampler
2053-
res = torch.from_numpy(
2054-
self.resampler(image_feature=x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
2042+
2043+
temporal_embed = torch.stack(pos_embed_temporal, dim=0).unsqueeze(0)
2044+
res = torch.from_numpy(
2045+
self.resampler(
2046+
image_feature=x,
2047+
pos_embed=pos_embed,
2048+
key_padding_mask=key_padding_mask,
2049+
temporal_embed=temporal_embed,
20552050
)
2051+
)
20562052
return res
20572053

20582054
def _set_2d_pos_cache(self, max_size):
@@ -4487,4 +4483,4 @@ def preprocess_inputs(
44874483
"phi4_multimodal": _OVPhi4MMForCausalLM,
44884484
"llama4": _OVLlama4ForCausalLM,
44894485
"minicpmo": _OVMiniCPMOForCausalLM,
4490-
}
4486+
}

0 commit comments

Comments
 (0)