Commit 09a3f19

to align with original model

1 parent 4aff6ed commit 09a3f19

File tree

1 file changed

+37 -12 lines changed


optimum/intel/openvino/modeling_visual_language.py

Lines changed: 37 additions & 12 deletions
@@ -287,19 +287,10 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
 
     def forward(self, image_feature, pos_embed, key_padding_mask, temporal_embed=None):
         self.compile()
+        inputs = {"image_feature": image_feature, "pos_embed": pos_embed, "key_padding_mask": key_padding_mask}
         if temporal_embed is not None:
-            result = self.request(
-                {
-                    "image_feature": image_feature,
-                    "pos_embed": pos_embed,
-                    "key_padding_mask": key_padding_mask,
-                    "temporal_embed": temporal_embed,
-                }
-            )[0]
-        else:
-            result = self.request(
-                {"image_feature": image_feature, "pos_embed": pos_embed, "key_padding_mask": key_padding_mask}
-            )[0]
+            inputs["temporal_embed"] = temporal_embed
+        result = self.request(inputs)[0]
         return result
 
 

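The hunk above is a pure refactor: both branches issued the same self.request(...) call, differing only in whether temporal_embed was included, so building one inputs dict and adding the optional key conditionally preserves behavior. A minimal sketch of the pattern outside the class (the request callable is a hypothetical stand-in for the compiled OpenVINO inference request):

# Assemble the inputs dict once, add optional entries conditionally,
# and issue a single inference call.
def run_resampler(request, image_feature, pos_embed, key_padding_mask, temporal_embed=None):
    inputs = {"image_feature": image_feature, "pos_embed": pos_embed, "key_padding_mask": key_padding_mask}
    if temporal_embed is not None:  # optional model input
        inputs["temporal_embed"] = temporal_embed
    return request(inputs)[0]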
@@ -2000,6 +1991,40 @@ def get_vision_embeddings(self, pixel_values, input_ids=None, temporal_ids=None,
            vision_hidden_states.append(dummy_feature)
        return vision_hidden_states
 
+    def get_1d_sincos_pos_embed_from_temporal_size(self, embed_dim, pos):
+        """
+        embed_dim: output dimension for each position
+        pos: a list of positions to be encoded: size (M,)
+        out: (M, D)
+        """
+        assert embed_dim % 2 == 0
+        omega = np.arange(embed_dim // 2, dtype=np.float32)
+        omega /= embed_dim / 2.0
+        omega = 1.0 / 10000**omega  # (D/2,)
+
+        pos = pos.reshape(-1)  # (M,)
+        out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
+
+        emb_sin = np.sin(out)  # (M, D/2)
+        emb_cos = np.cos(out)  # (M, D/2)
+
+        emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+        return emb
+
+    def _set_temporal_pos_cache(self, max_temporal_size, device="cpu"):
+        temporal_size = np.arange(max_temporal_size, dtype=np.float32)
+        pos_embed = (
+            torch.from_numpy(self.get_1d_sincos_pos_embed_from_temporal_size(self.embed_dim, temporal_size))
+            .float()
+            .to(device)
+        )
+        self.temporal_pos_embed = pos_embed
+
+    def _adjust_temporal_pos_cache(self, max_temporal_size, device):
+        if max_temporal_size > self.max_temporal_size:
+            self.max_temporal_size = max_temporal_size
+            self._set_temporal_pos_cache(self.max_temporal_size, device)
+
     def resampling(self, x, tgt_sizes, temporal_ids=None):
         from itertools import chain
 
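For reference, the positional-embedding helper added above can be exercised standalone. A minimal sketch, lifting the same body out of the class as a module-level function (the check at the end is illustrative only):

import numpy as np


def get_1d_sincos_pos_embed_from_temporal_size(embed_dim, pos):
    """Return (M, D) sine/cosine embeddings for M positions at even dimension D."""
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,) geometric frequency ladder

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)
    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)


# Illustrative check: 8 temporal positions at embedding dimension 64.
emb = get_1d_sincos_pos_embed_from_temporal_size(64, np.arange(8, dtype=np.float32))
print(emb.shape)  # (8, 64)

In the commit, _set_temporal_pos_cache wraps this helper to precompute embeddings for all positions up to max_temporal_size as a torch tensor, and _adjust_temporal_pos_cache regrows that cache on demand when a longer temporal sequence arrives.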