diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 4f5b0a55a4..62f03f8b09 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -2516,7 +2516,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
 
 
 class DummyMiniCPMVResampleInputGenerator(DummyVisionInputGenerator):
-    SUPPORTED_INPUT_NAMES = ("image_feature", "pos_embed", "key_padding_mask")
+    SUPPORTED_INPUT_NAMES = ("image_feature", "pos_embed", "key_padding_mask", "temporal_embed")
 
     def __init__(
         self,
@@ -2553,6 +2553,9 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
         if input_name == "pos_embed":
             return self.random_float_tensor(shape=[self.feat_size, self.batch_size, self.hidden_size])
 
+        if input_name == "temporal_embed":
+            return self.random_float_tensor(shape=[1, self.batch_size, self.hidden_size])
+
 
 class MiniCPMVConfigBehavior(str, enum.Enum):
     RESAMPLER = "resampler"
@@ -2585,6 +2588,8 @@ def __init__(
         )
         self._behavior = behavior
         self._orig_config = config
+        model_mapping = {2.6: "qwen2", 4.0: "llama", 4.5: "qwen3"}
+        self.model_type = model_mapping[self._orig_config.version]
         if self._behavior == MiniCPMVConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
             self._config = config.vision_config
             self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyMiniCPMVImageInputGenerator,)
@@ -2601,11 +2606,19 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
                 "position_ids": {0: "batch_size", 1: "patch_size"},
             }
         if self._behavior == MiniCPMVConfigBehavior.RESAMPLER:
-            return {
-                "image_feature": {0: "batch_size", 1: "patch_height", 2: "patch_width"},
-                "pos_embed": {0: "patch_size", 1: "batch_size", 2: "num_patches"},
-                "key_padding_mask": {0: "batch_size", 1: "patch_size"},
-            }
+            if self._orig_config.version == 4.5:
+                return {
+                    "image_feature": {0: "batch_size", 1: "patch_height", 2: "patch_width"},
+                    "pos_embed": {0: "patch_size", 1: "batch_size", 2: "num_patches"},
+                    "key_padding_mask": {0: "batch_size", 1: "patch_size"},
+                    "temporal_embed": {0: "patch_size", 1: "batch_size"},
+                }
+            else:
+                return {
+                    "image_feature": {0: "batch_size", 1: "patch_height", 2: "patch_width"},
+                    "pos_embed": {0: "patch_size", 1: "batch_size", 2: "num_patches"},
+                    "key_padding_mask": {0: "batch_size", 1: "patch_size"},
+                }
         return {}
 
     @property
@@ -2631,10 +2644,20 @@ def with_behavior(
         behavior = MiniCPMVConfigBehavior(behavior)
 
         if behavior == MiniCPMVConfigBehavior.TEXT_EMBEDDINGS:
-            return get_vlm_text_embeddings_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype)
+            return get_vlm_text_embeddings_config(
+                self.model_type,
+                self._orig_config,
+                self.int_dtype,
+                self.float_dtype,
+            )
 
         if behavior == MiniCPMVConfigBehavior.LANGUAGE:
-            return get_vlm_text_generation_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype)
+            return get_vlm_text_generation_config(
+                self.model_type,
+                self._orig_config,
+                self.int_dtype,
+                self.float_dtype,
+            )
 
         if behavior == MiniCPMVConfigBehavior.VISION_EMBEDDINGS:
             return self.__class__(
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index d1dbecf77b..04a4d2e59e 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -3323,6 +3323,27 @@ def _minicpmv_resampler_forward(self, image_feature, pos_embed, key_padding_mask
 
     out = self.attn(q_bs, image_feature + pos_embed, image_feature, key_padding_mask=key_padding_mask)[
         0
+    ]  # Q * B * D  # L * B * D +  L * B * D
+    # out: Q * B * D
+    x = out.permute(1, 0, 2)  # B * Q * D
+
+    x = self.ln_post(x)
+    x = x @ self.proj
+    return x
+
+
+def _minicpmv4_5_resampler_forward(self, image_feature, pos_embed, key_padding_mask, temporal_embed):
+    image_feature = self.kv_proj(image_feature)  # B * L * D
+    image_feature = self.ln_kv(image_feature).permute(1, 0, 2)  # L * B * D
+    image_feature_emb = image_feature + pos_embed
+    image_feature_temporal = image_feature_emb + temporal_embed  # [L, bs, D] + [1, bs, D]
+    bs = image_feature_temporal.shape[1]
+    q = self.ln_q(self.query)  # Q * D
+
+    q_bs = q.unsqueeze(1).repeat(1, bs, 1)
+
+    out = self.attn(q_bs, image_feature_temporal, image_feature, key_padding_mask=key_padding_mask)[
+        0
     ]  # Q * B * D  # L * B * D +  L * B * D
     # out: Q * B * D
     x = out.permute(1, 0, 2)  # B * Q * D
@@ -3482,7 +3503,10 @@ def __init__(
         model_kwargs: Dict[str, Any],
     ):
         model.__orig_forward = model.forward
-        model.forward = types.MethodType(_minicpmv_resampler_forward, model)
+        has_temporal_ids = "temporal_ids" in inspect.signature(model.__orig_forward).parameters
+        model.forward = types.MethodType(
+            _minicpmv4_5_resampler_forward if has_temporal_ids else _minicpmv_resampler_forward, model
+        )
 
         super().__init__(config, model, model_kwargs)
 
diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py
index 2e97fdc4cd..d74564da2e 100644
--- a/optimum/intel/openvino/modeling_visual_language.py
+++ b/optimum/intel/openvino/modeling_visual_language.py
@@ -285,11 +285,21 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
         self.output_dtypes = {key.get_any_name(): key.get_element_type().get_type_name() for key in self.model.outputs}
         self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
 
-    def forward(self, image_feature, pos_embed, key_padding_mask):
+    def forward(self, image_feature, pos_embed, key_padding_mask, temporal_embed=None):
         self.compile()
-        result = self.request(
-            {"image_feature": image_feature, "pos_embed": pos_embed, "key_padding_mask": key_padding_mask}
-        )[0]
+        if temporal_embed is not None:
+            result = self.request(
+                {
+                    "image_feature": image_feature,
+                    "pos_embed": pos_embed,
+                    "key_padding_mask": key_padding_mask,
+                    "temporal_embed": temporal_embed,
+                }
+            )[0]
+        else:
+            result = self.request(
+                {"image_feature": image_feature, "pos_embed": pos_embed, "key_padding_mask": key_padding_mask}
+            )[0]
 
         return result
 
@@ -784,6 +794,7 @@ def forward(
         audio_embed_sizes=None,
         audio_attention_mask=None,
         input_mode=None,
+        temporal_ids=None,
         **kwargs,
     ):
         if pixel_values is None:
@@ -809,6 +820,7 @@
                 audio_embed_sizes=audio_embed_sizes,
                 audio_attention_mask=audio_attention_mask,
                 input_mode=input_mode,
+                temporal_ids=temporal_ids,
                 **kwargs,
             )
         return self.language_model.forward(
@@ -921,6 +933,7 @@ def prepare_inputs_for_generation(
                 "input_audio_embeds": kwargs.get("input_audio_embeds", kwargs.get("audio_input_features")),
                 "audio_embed_sizes": kwargs.get("audio_embed_sizes"),
                 "input_mode": kwargs.get("input_mode"),
+                "temporal_ids": kwargs.get("temporal_ids"),
             }
         )
         return model_inputs
@@ -1923,10 +1936,18 @@ def __init__(
         max_size = self.config.vision_config.image_size // self.config.vision_config.patch_size
         self._pos_embeds = torch.from_numpy(self._get_2d_sincos_pos_embed(self.embed_dim, max_size)).float()
         self.max_size = (max_size, max_size)
+        self.max_temporal_size = 72000
 
-    def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
+    def get_vision_embeddings(self, pixel_values, input_ids=None, temporal_ids=None, **kwargs):
         if input_ids is not None and input_ids.shape[1] == 1:
             return None
+
+        all_temporal_ids = None
+        if temporal_ids is not None:
+            all_temporal_ids = []
+            for t in temporal_ids:
+                all_temporal_ids.extend(t)
+
         tgt_sizes = kwargs["tgt_sizes"]
         pixel_values_list = pixel_values
         vision_hidden_states = []
@@ -1963,7 +1984,7 @@ def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
                     pixel_values=all_pixel_values, patch_attention_mask=patch_attn_mask, position_ids=position_ids
                 )[0]
             )
-            vision_embedding = self.resampling(vision_embedding, tgt_sizes)
+            vision_embedding = self.resampling(vision_embedding, tgt_sizes, all_temporal_ids)
 
             start = 0
             for pixel_value in pixel_values_list:
@@ -1979,26 +2000,57 @@ def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
             vision_hidden_states.append(dummy_feature)
         return vision_hidden_states
 
-    def resampling(self, x, tgt_sizes):
+    def resampling(self, x, tgt_sizes, temporal_ids=None):
+        from itertools import chain
+
         bs = x.shape[0]
 
         patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
 
         self._adjust_pos_cache(tgt_sizes)
 
+        temporal_pos_emb = False
+        temporal_ids_flatten = None
+        if temporal_ids is not None:
+            # example: [[-1], [-1], [2, 6, 9]]
+            temporal_ids_flatten = list(chain.from_iterable(temporal_ids))
+            max_temporal_size = max(temporal_ids_flatten) + 1
+            if max_temporal_size > -1:
+                temporal_pos_emb = True
+                if max_temporal_size > self.max_temporal_size:
+                    self._adjust_temporal_pos_cache(max_temporal_size, "cpu")
+
         max_patch_len = torch.max(patch_len)
         key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool)
 
+        temporal_embed = None
         pos_embed = []
+        pos_embed_temporal = []
         for i in range(bs):
             tgt_h, tgt_w = tgt_sizes[i]
+
+            if temporal_pos_emb:
+                if temporal_ids_flatten[i] == -1:
+                    pos_embed_temporal.append(torch.zeros(self.embed_dim, dtype=torch.float32, device="cpu"))
+                else:
+                    pos_embed_temporal.append(self.temporal_pos_embed[temporal_ids_flatten[i]].to(torch.float32))  # D
+
             pos_embed.append(self._pos_embeds[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)))  # patches * D
             key_padding_mask[i, patch_len[i] :] = True
 
         pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, padding_value=0.0).permute(
             1, 0, 2
         )  # BLD => L * B * D
-        res = torch.from_numpy(self.resampler(image_feature=x, pos_embed=pos_embed, key_padding_mask=key_padding_mask))
+        if temporal_pos_emb:
+            temporal_embed = torch.stack(pos_embed_temporal, dim=0).unsqueeze(0)
+        res = torch.from_numpy(
+            self.resampler(
+                image_feature=x,
+                pos_embed=pos_embed,
+                key_padding_mask=key_padding_mask,
+                temporal_embed=temporal_embed,
+            )
+        )
         return res
 
     def _set_2d_pos_cache(self, max_size):
diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py
index bba9c3b92a..d2503d7a21 100644
--- a/tests/openvino/test_exporters_cli.py
+++ b/tests/openvino/test_exporters_cli.py
@@ -618,6 +618,28 @@ class OVCLIExportTestCase(unittest.TestCase):
                 "resampler_model": {"int8": 6},
             },
         ),
+        (
+            "image-text-to-text",
+            "minicpmv4",
+            "int4 --group-size 4 --ratio 0.8 --trust-remote-code",
+            {
+                "lm_model": {"int8": 12, "int4": 18},
+                "text_embeddings_model": {"int8": 1},
+                "vision_embeddings_model": {"int8": 14},
+                "resampler_model": {"int8": 6},
+            },
+        ),
+        (
+            "image-text-to-text",
+            "minicpmv4_5",
+            "int4 --group-size 4 --ratio 0.8 --trust-remote-code",
+            {
+                "lm_model": {"int8": 12, "int4": 18},
"int4": 18}, + "text_embeddings_model": {"int8": 1}, + "vision_embeddings_model": {"int8": 14}, + "resampler_model": {"int8": 6}, + }, + ), ( "image-text-to-text", "internvl_chat", @@ -743,13 +765,13 @@ def _openvino_export(self, model_name: str, task: str, model_kwargs: Dict = None def test_filtered_architectures(cls): if is_transformers_version("<", "4.49"): - expected = {"llama4", "qwen2_5_vl", "phi4mm"} + expected = {"llama4", "qwen2_5_vl", "phi4mm", "minicpmv4", "minicpmv4_5"} elif is_transformers_version("<", "4.51"): expected = {"llama4", "phi4mm"} elif is_transformers_version("<", "4.52"): expected = set() else: - expected = {"llava-qwen2", "phi3_v", "phi4mm", "minicpmo"} + expected = {"llava-qwen2", "phi3_v", "phi4mm", "minicpmo", "minicpmv4", "minicpmv4_5"} all_model_type = {config[1] for config in cls.TRANSFORMERS_4BIT_CONFIGURATIONS} filtered_model_type = {config[1] for config in cls.SUPPORTED_4BIT_CONFIGURATIONS} diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 4cb154a0d9..44e671fdcb 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -938,6 +938,48 @@ class OVWeightCompressionTest(unittest.TestCase): "resampler_model": {"int8": 6}, }, ), + ( + OVModelForVisualCausalLM, + "minicpmv4", + True, + dict( + bits=4, + group_size=16, + dataset="contextual", + ratio=0.8, + sensitivity_metric="mean_activation_magnitude", + num_samples=1, + processor=MODEL_NAMES["minicpmv4"], + trust_remote_code=True, + ), + { + "lm_model": {"int8": 8, "int4": 22}, + "text_embeddings_model": {"int8": 1}, + "vision_embeddings_model": {"int8": 26}, + "resampler_model": {"int8": 6}, + }, + ), + ( + OVModelForVisualCausalLM, + "minicpmv4_5", + True, + dict( + bits=4, + group_size=16, + dataset="contextual", + ratio=0.8, + sensitivity_metric="mean_activation_magnitude", + num_samples=1, + processor=MODEL_NAMES["minicpmv4_5"], + trust_remote_code=True, + ), + { + "lm_model": {"int8": 8, "int4": 22}, + "text_embeddings_model": {"int8": 1}, + "vision_embeddings_model": {"int8": 26}, + "resampler_model": {"int8": 6}, + }, + ), ] # filter models type depending on min max transformers version @@ -964,6 +1006,8 @@ class OVWeightCompressionTest(unittest.TestCase): (OVModelForVisualCausalLM, "llava_next_video", False), (OVModelForVisualCausalLM, "minicpmv", True), (OVModelForVisualCausalLM, "qwen2_vl", False), + (OVModelForVisualCausalLM, "minicpmv4", True), + (OVModelForVisualCausalLM, "minicpmv4_5", True), ] if is_transformers_version("<", "4.54.0"): @@ -987,13 +1031,13 @@ class OVWeightCompressionTest(unittest.TestCase): def test_filtered_architectures(cls): if is_transformers_version("<", "4.49"): - expected = {"llama4", "qwen2_5_vl"} + expected = {"llama4", "qwen2_5_vl", "minicpmv4", "minicpmv4_5"} elif is_transformers_version("<", "4.51"): expected = {"llama4"} elif is_transformers_version("<", "4.52"): expected = set() else: - expected = {"llava-qwen2", "phi3_v", "minicpmo"} + expected = {"llava-qwen2", "phi3_v", "minicpmo", "minicpmv4", "minicpmv4_5"} all_model_type = {config[1] for config in cls.TRANSFORMERS_4BIT_CONFIGURATIONS} filtered_model_type = {config[1] for config in cls.LOAD_IN_4_BITS_SCOPE} diff --git a/tests/openvino/test_seq2seq.py b/tests/openvino/test_seq2seq.py index b4e7c43971..7ba472f21d 100644 --- a/tests/openvino/test_seq2seq.py +++ b/tests/openvino/test_seq2seq.py @@ -490,7 +490,7 @@ class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase): if is_transformers_version(">", "4.49"): 
SUPPORTED_ARCHITECTURES += ["gemma3", "smolvlm"] if is_transformers_version(">=", "4.51"): - SUPPORTED_ARCHITECTURES += ["llama4"] + SUPPORTED_ARCHITECTURES += ["llama4", "minicpmv4", "minicpmv4_5"] if is_transformers_version("<", "4.52"): SUPPORTED_ARCHITECTURES += ["minicpmo"] @@ -499,7 +499,17 @@ class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = set(SUPPORTED_ARCHITECTURES) - {"llava-qwen2", "phi3_v", "phi4mm"} TASK = "image-text-to-text" - REMOTE_CODE_MODELS = ["internvl_chat", "minicpmv", "minicpmo", "llava-qwen2", "phi3_v", "maira2", "phi4mm"] + REMOTE_CODE_MODELS = [ + "internvl_chat", + "minicpmv", + "minicpmv4", + "minicpmv4_5", + "minicpmo", + "llava-qwen2", + "phi3_v", + "maira2", + "phi4mm", + ] IMAGE = Image.open( requests.get( @@ -608,7 +618,7 @@ def test_compare_to_transformers(self, model_arch): self._check_device_and_request(ov_model, test_device, False) # pytorch minicpmv and internvl_chat are not designed to be used via forward - if model_arch not in ["minicpmv", "minicpmo", "internvl_chat"]: + if model_arch not in ["minicpmv", "minicpmv4", "minicpmv4_5", "minicpmo", "internvl_chat"]: set_seed(SEED) ov_outputs = ov_model(**inputs) set_seed(SEED) @@ -655,7 +665,7 @@ def test_compare_to_transformers(self, model_arch): transformers_inputs["past_key_values"] = DynamicCache() with torch.no_grad(): - if model_arch in ["minicpmo"]: + if model_arch in ["minicpmo", "minicpmv4", "minicpmv4_5"]: # `generate` method for minicpmo requires tokenizer tokenizer = AutoTokenizer.from_pretrained( model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS @@ -669,7 +679,7 @@ def test_compare_to_transformers(self, model_arch): transformers_outputs = transformers_outputs[1].sequences # original minicpmv, internvl always skip input tokens in generation results, while transformers based approach provide them - if model_arch in ["minicpmv", "minicpmo", "internvl_chat"]: + if model_arch in ["minicpmv", "minicpmv4", "minicpmv4_5", "minicpmo", "internvl_chat"]: ov_outputs = ov_outputs[:, inputs["input_ids"].shape[1] :] self.assertTrue( torch.equal(ov_outputs, transformers_outputs), @@ -695,7 +705,7 @@ def test_compare_to_transformers(self, model_arch): transformers_inputs = copy.deepcopy(inputs) ov_outputs = ov_model.generate(**inputs, generation_config=gen_config) # original minicpmv, internvl always skip input tokens in generation results, while transformers based approach provide them - if model_arch in ["minicpmv", "minicpmo", "internvl_chat"]: + if model_arch in ["minicpmv", "minicpmv4", "minicpmv4_5", "minicpmo", "internvl_chat"]: ov_outputs = ov_outputs[:, inputs["input_ids"].shape[1] :] with torch.no_grad(): transformers_outputs = transformers_model.generate( @@ -713,7 +723,7 @@ def test_compare_to_transformers(self, model_arch): transformers_inputs = copy.deepcopy(inputs) ov_outputs = ov_model.generate(**inputs, generation_config=gen_config) # original minicpmv, internvl always skip input tokens in generation results, while transformers based approach provide them - if model_arch in ["minicpmv", "minicpmo", "internvl_chat"]: + if model_arch in ["minicpmv", "minicpmv4", "minicpmv4_5", "minicpmo", "internvl_chat"]: ov_outputs = ov_outputs[:, inputs["input_ids"].shape[1] :] with torch.no_grad(): transformers_outputs = transformers_model.generate( @@ -836,6 +846,9 @@ def test_generate_utils(self, model_arch): input_audio = self._generate_random_audio_data() question = "Translate this audio to French" inputs = 
+        # skip the temporal_ids input, which makes the number of loops inconsistent:
+        # https://huggingface.co/openbmb/MiniCPM-V-4_5/blob/main/resampler.py#L261
+        inputs.pop("temporal_ids", None)
         outputs = model.generate(**inputs, max_new_tokens=10)
         # filter out original prompt becuase it may contains out of tokenizer tokens e.g. in nanollva text separator = -200
         outputs = outputs[:, inputs["input_ids"].shape[1] :]
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index 7801ba17bb..3ac828dd50 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -125,6 +125,8 @@
     "minicpm": "optimum-intel-internal-testing/tiny-random-minicpm",
     "minicpm3": "optimum-intel-internal-testing/tiny-random-minicpm3",
     "minicpmv": "optimum-intel-internal-testing/tiny-random-minicpmv-2_6",
+    "minicpmv4": "optimum-intel-internal-testing/tiny-random-minicpm-v-4",
+    "minicpmv4_5": "optimum-intel-internal-testing/tiny-random-minicpm-v-4_5",
     "minicpmo": "optimum-intel-internal-testing/tiny-random-MiniCPM-o-2_6",
     "mistral": "optimum-intel-internal-testing/tiny-random-mistral",
     "mistral-nemo": "optimum-intel-internal-testing/tiny-random-mistral-nemo",
@@ -288,6 +290,18 @@
         "vision_embeddings_model": 26,
         "resampler_model": 6,
     },
+    "minicpmv4": {
+        "lm_model": 30,
+        "text_embeddings_model": 1,
+        "vision_embeddings_model": 14,
+        "resampler_model": 6,
+    },
+    "minicpmv4_5": {
+        "lm_model": 30,
+        "text_embeddings_model": 1,
+        "vision_embeddings_model": 14,
+        "resampler_model": 6,
+    },
     "llava_next_video": {
         "lm_model": 30,
         "text_embeddings_model": 1,