[OpenVINO] Add support for minicpmv4/4_5 #1412
base: main
@@ -285,11 +285,21 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
         self.output_dtypes = {key.get_any_name(): key.get_element_type().get_type_name() for key in self.model.outputs}
         self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}

-    def forward(self, image_feature, pos_embed, key_padding_mask):
+    def forward(self, image_feature, pos_embed, key_padding_mask, temporal_embed=None):
         self.compile()
-        result = self.request(
-            {"image_feature": image_feature, "pos_embed": pos_embed, "key_padding_mask": key_padding_mask}
-        )[0]
+        if temporal_embed is not None:
+            result = self.request(
+                {
+                    "image_feature": image_feature,
+                    "pos_embed": pos_embed,
+                    "key_padding_mask": key_padding_mask,
+                    "temporal_embed": temporal_embed,
+                }
+            )[0]
+        else:
+            result = self.request(
+                {"image_feature": image_feature, "pos_embed": pos_embed, "key_padding_mask": key_padding_mask}
+            )[0]
         return result
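For readers, the branch above can also be expressed by building the request inputs once and adding the optional key only when it is present. This is an editorial sketch of the same behavior under the members shown in the diff (`self.compile()`, `self.request()`), not the code the PR actually adds.

```python
# Sketch only: equivalent to the if/else branch in the diff above.
def forward(self, image_feature, pos_embed, key_padding_mask, temporal_embed=None):
    self.compile()
    inputs = {
        "image_feature": image_feature,
        "pos_embed": pos_embed,
        "key_padding_mask": key_padding_mask,
    }
    if temporal_embed is not None:
        # Only the MiniCPM-V 4.5 resampler exposes this extra input.
        inputs["temporal_embed"] = temporal_embed
    return self.request(inputs)[0]
```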
@@ -784,6 +794,7 @@ def forward(
         audio_embed_sizes=None,
         audio_attention_mask=None,
         input_mode=None,
+        temporal_ids=None,
         **kwargs,
     ):
         if pixel_values is None:
@@ -809,6 +820,7 @@ def forward(
                 audio_embed_sizes=audio_embed_sizes,
                 audio_attention_mask=audio_attention_mask,
                 input_mode=input_mode,
+                temporal_ids=temporal_ids,
                 **kwargs,
             )
         return self.language_model.forward(
@@ -921,6 +933,7 @@ def prepare_inputs_for_generation(
                 "input_audio_embeds": kwargs.get("input_audio_embeds", kwargs.get("audio_input_features")),
                 "audio_embed_sizes": kwargs.get("audio_embed_sizes"),
                 "input_mode": kwargs.get("input_mode"),
+                "temporal_ids": kwargs.get("temporal_ids"),
             }
         )
         return model_inputs
@@ -1923,10 +1936,18 @@ def __init__(
         max_size = self.config.vision_config.image_size // self.config.vision_config.patch_size
         self._pos_embeds = torch.from_numpy(self._get_2d_sincos_pos_embed(self.embed_dim, max_size)).float()
         self.max_size = (max_size, max_size)
+        self.max_temporal_size = 72000
Review comment: Why 72000? Should this value be loaded from the config?
Reply: It is the default for this model (https://huggingface.co/openbmb/MiniCPM-V-4_5/blob/main/resampler.py#L100) and is not initialized from the config (https://huggingface.co/openbmb/MiniCPM-V-4_5/blob/main/modeling_minicpmv.py#L50).
-    def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
+    def get_vision_embeddings(self, pixel_values, input_ids=None, temporal_ids=None, **kwargs):
         if input_ids is not None and input_ids.shape[1] == 1:
             return None

+        all_temporal_ids = None
+        if temporal_ids is not None:
+            all_temporal_ids = []
+            for t in temporal_ids:
+                all_temporal_ids.extend(t)
Review comment (on lines +1936 to +1940): a change was suggested here.
Reply: It is copied from the original model: https://huggingface.co/openbmb/MiniCPM-V-4_5/blob/main/modeling_minicpmv.py#L94
         tgt_sizes = kwargs["tgt_sizes"]
         pixel_values_list = pixel_values
         vision_hidden_states = []
@@ -1963,7 +1984,7 @@ def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
                 pixel_values=all_pixel_values, patch_attention_mask=patch_attn_mask, position_ids=position_ids
             )[0]
         )
-        vision_embedding = self.resampling(vision_embedding, tgt_sizes)
+        vision_embedding = self.resampling(vision_embedding, tgt_sizes, all_temporal_ids)

         start = 0
         for pixel_value in pixel_values_list:
@@ -1979,26 +2000,57 @@ def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
             vision_hidden_states.append(dummy_feature)
         return vision_hidden_states

-    def resampling(self, x, tgt_sizes):
+    def resampling(self, x, tgt_sizes, temporal_ids=None):
+        from itertools import chain
Review comment: Should be imported at the top of the file.
Reply: This import is used by minicpmv only, so I think it can be left here, e.g. https://github.com/huggingface/optimum-intel/blob/main/optimum/intel/openvino/modeling_visual_language.py#L1229
         bs = x.shape[0]

         patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]

         self._adjust_pos_cache(tgt_sizes)

+        temporal_pos_emb = False
Review comment: For me these names are a bit confusing.
Reply: I only created …
+        temporal_ids_flatten = None
+        if temporal_ids is not None:
+            # example: [[-1], [-1], [2, 6, 9]]
+            temporal_ids_flatten = list(chain.from_iterable(temporal_ids))
Review comment (on lines +2040 to +2041): Do we actually need to do an additional flattening pass here? As I understand …
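For illustration only (the nesting below is my reading based on the original MiniCPM-V-4_5 code, not something stated in this diff): `temporal_ids` arrives with one sub-list per sample, `get_vision_embeddings` flattens the batch level into `all_temporal_ids`, and `resampling` then flattens the per-image level, so under this assumption the two passes are not strictly redundant.

```python
from itertools import chain

# Assumed nesting (illustrative): batch -> images per sample -> temporal id per slice.
temporal_ids = [[[-1]], [[-1], [2, 6, 9]]]  # two samples; the second carries two images

# What get_vision_embeddings does: flatten the batch level with extend().
all_temporal_ids = []
for t in temporal_ids:
    all_temporal_ids.extend(t)
print(all_temporal_ids)      # [[-1], [-1], [2, 6, 9]]  (matches the "# example" comment above)

# What resampling() then does: flatten the per-image level.
temporal_ids_flatten = list(chain.from_iterable(all_temporal_ids))
print(temporal_ids_flatten)  # [-1, -1, 2, 6, 9]  -> one entry per image slice in the batch
```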
+            max_temporal_size = max(temporal_ids_flatten) + 1
+            if max_temporal_size > -1:
+                temporal_pos_emb = True
+                if max_temporal_size > self.max_temporal_size:
+                    self._adjust_temporal_pos_cache(max_temporal_size, "cpu")
Review comment: I don't see a definition of `_adjust_temporal_pos_cache`.
Reply: Updated to align with the original model.
         max_patch_len = torch.max(patch_len)
         key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool)
+        temporal_embed = None
         pos_embed = []
+        pos_embed_temporal = []
         for i in range(bs):
             tgt_h, tgt_w = tgt_sizes[i]

+            if temporal_pos_emb:
+                if temporal_ids_flatten[i] == -1:
+                    pos_embed_temporal.append(torch.zeros(self.embed_dim, dtype=torch.float32, device="cpu"))
+                else:
+                    pos_embed_temporal.append(self.temporal_pos_embed[temporal_ids_flatten[i]].to(torch.float32))  # D
Review comment: Where is `temporal_pos_embed` defined?
Reply: Fixed.
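The `temporal_pos_embed` cache referenced above is not visible in this hunk (hence the comment). Purely as an assumption-labeled sketch, a 1D sin-cos table analogous to the existing 2D `_pos_embeds` could be built as below; the helper name and its placement are hypothetical, and the actual fix in the PR may differ.

```python
import numpy as np

def _get_1d_sincos_pos_embed(embed_dim, num_positions):
    # Hypothetical helper: standard 1D sin-cos table of shape (num_positions, embed_dim).
    positions = np.arange(num_positions, dtype=np.float32)
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float32) / (embed_dim / 2.0))
    angles = np.einsum("p,d->pd", positions, omega)                   # (P, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)   # (P, D)

# Hypothetical cache, indexed as self.temporal_pos_embed[temporal_id] above, e.g.:
# self.temporal_pos_embed = torch.from_numpy(
#     _get_1d_sincos_pos_embed(self.embed_dim, needed_temporal_size)
# ).float()
```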
             pos_embed.append(self._pos_embeds[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)))  # patches * D
             key_padding_mask[i, patch_len[i] :] = True

         pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, padding_value=0.0).permute(
             1, 0, 2
         )  # BLD => L * B * D
-        res = torch.from_numpy(self.resampler(image_feature=x, pos_embed=pos_embed, key_padding_mask=key_padding_mask))
+        if temporal_pos_emb:
Review comment: a change was suggested here.
Reply: It is copied from the original model: https://huggingface.co/openbmb/MiniCPM-V-4_5/blob/main/resampler.py#L216
+            temporal_embed = torch.stack(pos_embed_temporal, dim=0).unsqueeze(0)
+        res = torch.from_numpy(
+            self.resampler(
+                image_feature=x,
+                pos_embed=pos_embed,
+                key_padding_mask=key_padding_mask,
+                temporal_embed=temporal_embed,
+            )
+        )
         return res
     def _set_2d_pos_cache(self, max_size):
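As a quick shape check of the stacking above (illustrative toy sizes): stacking the per-item D-dimensional temporal embeddings and unsqueezing yields a (1, B, D) tensor, which lines up with the L × B × D layout already used for `pos_embed` as if it were a single extra position.

```python
import torch

bs, embed_dim = 3, 8                                              # toy sizes for illustration
pos_embed_temporal = [torch.zeros(embed_dim) for _ in range(bs)]  # one (D,) vector per item

temporal_embed = torch.stack(pos_embed_temporal, dim=0).unsqueeze(0)
print(temporal_embed.shape)                                       # torch.Size([1, 3, 8]) -> (1, B, D)
```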
@@ -618,6 +618,28 @@ class OVCLIExportTestCase(unittest.TestCase):
                 "resampler_model": {"int8": 6},
             },
         ),
+        (
+            "image-text-to-text",
+            "minicpmv4",
+            "int4 --group-size 4 --ratio 0.8 --trust-remote-code",
+            {
+                "lm_model": {"int8": 12, "int4": 18},
+                "text_embeddings_model": {"int8": 1},
+                "vision_embeddings_model": {"int8": 14},
+                "resampler_model": {"int8": 6},
+            },
+        ),
+        (
+            "image-text-to-text",
+            "minicpmv4_5",
+            "int4 --group-size 4 --ratio 0.8 --trust-remote-code",
+            {
+                "lm_model": {"int8": 12, "int4": 18},
+                "text_embeddings_model": {"int8": 1},
+                "vision_embeddings_model": {"int8": 14},
+                "resampler_model": {"int8": 6},
+            },
+        ),
         (
             "image-text-to-text",
             "internvl_chat",
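For context, the new test entries exercise the CLI int4 export path with `--group-size 4 --ratio 0.8 --trust-remote-code`. A hedged sketch of roughly the same export through the Python API is shown below; the Hub model id is illustrative, and the per-model counts checked by the test apply to the test checkpoints rather than to this full model.

```python
from optimum.intel import OVModelForVisualCausalLM, OVWeightQuantizationConfig

# Roughly equivalent to "int4 --group-size 4 --ratio 0.8 --trust-remote-code" (sketch only).
quant_config = OVWeightQuantizationConfig(bits=4, group_size=4, ratio=0.8)
model = OVModelForVisualCausalLM.from_pretrained(
    "openbmb/MiniCPM-V-4_5",        # illustrative checkpoint id
    export=True,
    trust_remote_code=True,
    quantization_config=quant_config,
)
model.save_pretrained("minicpmv4_5-ov-int4")
```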
@@ -743,13 +765,13 @@ def _openvino_export(self, model_name: str, task: str, model_kwargs: Dict = None

     def test_filtered_architectures(cls):
         if is_transformers_version("<", "4.49"):
-            expected = {"llama4", "qwen2_5_vl", "phi4mm"}
+            expected = {"llama4", "qwen2_5_vl", "phi4mm", "minicpmv4", "minicpmv4_5"}
         elif is_transformers_version("<", "4.51"):
             expected = {"llama4", "phi4mm"}
         elif is_transformers_version("<", "4.52"):
             expected = set()
         else:
-            expected = {"llava-qwen2", "phi3_v", "phi4mm", "minicpmo"}
+            expected = {"llava-qwen2", "phi3_v", "phi4mm", "minicpmo", "minicpmv4", "minicpmv4_5"}
Review comment (on lines 767 to +774): From this, I get the understanding that minicpmv4/minicpmv4_5 are supported for …
Reply: I don't see any limitation on these two models. They can share the same version of transformers as minicpm-v-2.6.
         all_model_type = {config[1] for config in cls.TRANSFORMERS_4BIT_CONFIGURATIONS}
         filtered_model_type = {config[1] for config in cls.SUPPORTED_4BIT_CONFIGURATIONS}