diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 9c1684db81..628cac0f7e 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -138,13 +138,14 @@ Qwen2MoEPatcher, Qwen2VLLanguageModelPatcher, Qwen2VLVisionEmbMergerPatcher, + Qwen3VLVisionEmbMergerPatcher, + Qwen3VLLanguageModelPatcher, Qwen3MoeModelPatcher, QwenModelPatcher, SanaTextEncoderModelPatcher, XverseModelPatcher, ) - def init_model_configs(): if "open_clip" not in TasksManager._LIBRARY_TO_SUPPORTED_MODEL_TYPES: TasksManager._LIBRARY_TO_SUPPORTED_MODEL_TYPES["open_clip"] = {} @@ -164,6 +165,14 @@ def init_model_configs(): "transformers", "AutoModelForImageTextToText", ) + TasksManager._CUSTOM_CLASSES[("pt", "qwen3_vl", "image-text-to-text")] = ( + "transformers", + "AutoModelForImageTextToText", + ) + TasksManager._CUSTOM_CLASSES[("pt", "qwen3_vl_moe", "image-text-to-text")] = ( + "transformers", + "AutoModelForImageTextToText", + ) TasksManager._CUSTOM_CLASSES[("pt", "llava_next_video", "image-text-to-text")] = ( "transformers", "AutoModelForVision2Seq", @@ -333,6 +342,92 @@ def patch_model_for_export( ) -> "ModelPatcher": return OVDecoderModelPatcher(self, model, model_kwargs=model_kwargs) +class DummyQwen3VLLMInputGenerator(DummyTextInputGenerator): + SUPPORTED_INPUT_NAMES = ( + "input_ids", + "attention_mask", + "encoder_attention_mask", + "token_type_ids", + "position_ids", + "visual_pos_masks", + "deepstack_visual_embeds", + ) + + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + num_choices: int = DEFAULT_DUMMY_SHAPES["num_choices"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + random_num_choices_range: Optional[Tuple[int, int]] = None, + padding_side: str = "right", + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + num_choices=num_choices, + random_batch_size_range=random_batch_size_range, + random_sequence_length_range=random_sequence_length_range, + random_num_choices_range=random_num_choices_range, + padding_side=padding_side, + **kwargs, + ) + self.embed_dim = normalized_config.hidden_size + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32", bool_dtype: str = "bool"): + if input_name == "deepstack_visual_embeds": + return self.random_float_tensor([3, 2*self.sequence_length, self.embed_dim], framework=framework, dtype=float_dtype) + if input_name == "visual_pos_masks": + return self.constant_tensor( + shape=[self.batch_size, self.sequence_length], + framework=framework, + value=1, + dtype=DTYPE_MAPPER.pt(bool_dtype), + ) + return super().generate(input_name, framework, int_dtype, float_dtype) + +@register_in_tasks_manager( + "qwen3_vl_text", + *[ + "text-generation", + "text-generation-with-past", + ], + library_name="transformers", +) +@register_in_tasks_manager( + "qwen3_vl_moe_text", + *[ + "text-generation", + "text-generation-with-past", + ], + library_name="transformers", +) +class Qwen3VLTextOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + MIN_TRANSFORMERS_VERSION = "4.56.0" + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen3VLLMInputGenerator, GemmaDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS 
= GemmaDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + common_inputs = super().inputs + common_inputs["visual_pos_masks"] = {0: "batch_size", 1: "sequence_length"} + common_inputs["deepstack_visual_embeds"] = {0: "num_layers", 1: "visual_seqlen"} + return common_inputs + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return OVDecoderModelPatcher(self, model, model_kwargs=model_kwargs) + + @register_in_tasks_manager( "qwen3_moe", @@ -3437,6 +3532,8 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int return generated_input + + class DummyQwen2VLVisionEmbedInputGenerator(DummyVisionInputGenerator): SUPPORTED_INPUT_NAMES = ( "hidden_states", @@ -3503,6 +3600,75 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int return self.random_int_tensor([hidden_size], max_value=hidden_size) +class DummyQwen3VLVisionEmbedInputGenerator(DummyVisionInputGenerator): + SUPPORTED_INPUT_NAMES = ( + "hidden_states", + "attention_mask", + "window_attention_mask", + "window_index", + "rotary_pos_emb", + "input", + ) + + def __init__( + self, + task: str, + normalized_config: NormalizedVisionConfig, + batch_size: int = 1, + num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"], + width: int = 420, + height: int = 420, + **kwargs, + ): + self.batch_size = batch_size + self.height = height + self.width = width + self.num_channels = num_channels + self.temporal_patch_size = normalized_config.config.temporal_patch_size + self.patch_size = normalized_config.config.patch_size + if normalized_config.use_embed_dim: + self.embed_dim = ( + normalized_config.config.embed_dim + if hasattr(normalized_config.config, "embed_dim") + else normalized_config.hidden_size + ) + else: + self.embed_dim = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size + self.num_heads = normalized_config.config.num_heads + self.spatial_merge_size = None + if hasattr(normalized_config.config, "spatial_merge_size"): + self.spatial_merge_size = normalized_config.config.spatial_merge_size + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + grid_h, grid_w = self.height // self.patch_size, self.width // self.patch_size + grid_t = self.batch_size + + if input_name == "hidden_states": + return self.random_float_tensor( + [grid_t * grid_h * grid_w, self.embed_dim], framework=framework, dtype=float_dtype + ) + + if input_name in ["attention_mask", "window_attention_mask"]: + return self.random_mask_tensor( + [1, grid_t * grid_h * grid_w, grid_t * grid_h * grid_w], framework=framework, dtype=float_dtype + ) + + if input_name == "rotary_pos_emb": + dim = self.embed_dim // self.num_heads // 2 + return self.random_float_tensor([grid_h * grid_t * grid_w, dim], framework=framework, dtype=float_dtype) + + if input_name == "input": + return self.constant_tensor([4, 2520], framework=framework, value=0, dtype=DTYPE_MAPPER.pt(int_dtype)) + + if input_name == "window_index": + if self.spatial_merge_size is None: + raise ValueError( + "`spatial_merge_size` parameter is not found in model config. 
Can not generate dummy input data for `window_index` input" + ) + spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size + hidden_size = (grid_t * grid_h * grid_w) // spatial_merge_unit + return self.random_int_tensor([hidden_size], max_value=hidden_size) + class Qwen2VLConfigBehavior(str, enum.Enum): LANGUAGE = "language" VISION_EMBEDDINGS = "vision_embeddings" @@ -3674,6 +3840,241 @@ def patch_model_for_export( if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER: return Qwen2_5_VLVisionEmbMergerPatcher(self, model, model_kwargs) return super().patch_model_for_export(model, model_kwargs) + +class Qwen3VLConfigBehavior(str, enum.Enum): + LANGUAGE = "language" + VISION_EMBEDDINGS = "vision_embeddings" + VISION_EMBEDDINGS_MERGER = "vision_embeddings_merger" + TEXT_EMBEDDINGS = "text_embeddings" + VISION_EMBEDDINGS_POS = "vision_embeddings_pos" + +@register_in_tasks_manager( + "qwen3_vl", + *["image-text-to-text", "video-text-to-text"], + library_name="transformers", +) +class Qwen3_VLOpenVINOConfig(BaseVLMOpenVINOConfig): + SUPPORTED_BEHAVIORS = [model_type.value for model_type in Qwen3VLConfigBehavior] + NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig + DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen3VLVisionEmbedInputGenerator,) + MIN_TRANSFORMERS_VERSION = version.parse("4.56.0") + + def __init__( + self, + config: "PretrainedConfig", + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + behavior: Qwen3VLConfigBehavior = Qwen3VLConfigBehavior.VISION_EMBEDDINGS, + preprocessors: Optional[List[Any]] = None, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + ) + self._behavior = behavior + self._orig_config = config + if self._behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"): + self._config = config.vision_config + self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) + self._normalized_config.use_embed_dim = False + if self._behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS_MERGER and hasattr(config, "vision_config"): + self._config = config.vision_config + self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) + self._normalized_config.use_embed_dim = True + if self._behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS_POS and hasattr(config, "vision_config"): + self._config = config.vision_config + self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) + self._normalized_config.use_embed_dim = True + + + + + @staticmethod + def get_model_for_behavior(model, behavior: Union[str, Qwen3VLConfigBehavior]): + if isinstance(behavior, str) and not isinstance(behavior, Qwen3VLConfigBehavior): + behavior = Qwen3VLConfigBehavior(behavior) + + if behavior == Qwen3VLConfigBehavior.LANGUAGE: + return model + + if behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS: + vision_embeddings = model.visual.patch_embed + vision_embeddings.config = model.config.vision_config + return vision_embeddings + + if behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS_MERGER: + vision_emb_merger = model.visual + vision_emb_merger.config = model.config.vision_config + return vision_emb_merger + + if behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS_POS: + vision_emb_pos = model.visual.pos_embed + vision_emb_pos.config = model.config.vision_config + return vision_emb_pos + + if behavior == Qwen3VLConfigBehavior.TEXT_EMBEDDINGS: + text_embedding = ( + model.model.embed_tokens if 
hasattr(model.model, "embed_tokens") else model.language_model.embed_tokens + ) + text_embedding.config = model.config + return text_embedding + + def with_behavior( + self, + behavior: Union[str, Qwen3VLConfigBehavior], + ): + """ + Creates a config for different behaviour. + Args: + behavior ([`ConfigBehavior`]): + The behavior to use for the new instance. + """ + if isinstance(behavior, str) and not isinstance(behavior, Qwen3VLConfigBehavior): + behavior = Qwen3VLConfigBehavior(behavior) + + if behavior == Qwen3VLConfigBehavior.TEXT_EMBEDDINGS: + return get_vlm_text_embeddings_config("qwen3_vl_text", self._orig_config.text_config, self.int_dtype, self.float_dtype) + + if behavior == Qwen3VLConfigBehavior.LANGUAGE: + return get_vlm_text_generation_config( + "qwen3_vl_text", + self._orig_config.text_config, + self.int_dtype, + self.float_dtype, + model_patcher=Qwen3VLLanguageModelPatcher, + dummy_input_generator=DummyQwen2VLLMInputGenerator, + inputs_update={"position_ids": {1: "batch_size", 2: "sequence_length"}}, + ) + + if behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS: + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + if behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS_MERGER: + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + if behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS_POS: + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ): + model_kwargs = model_kwargs or {} + if self._behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS_MERGER: + return Qwen3VLVisionEmbMergerPatcher(self, model, model_kwargs) + if self._behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS or self._behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS_POS: + return ModelPatcher(self, model, model_kwargs=model_kwargs) + return super().patch_model_for_export(model, model_kwargs) + + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + if self._behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS: + return {"hidden_states": {0: "patch_thw_grid", 1: "patch_temporal_channels"}} + if self._behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS_MERGER: + return { + "hidden_states": {0: "sequence_length"}, + "attention_mask": {1: "sequence_length", 2: "sequence_length"}, + "rotary_pos_emb": {0: "sequence_length"}, + } + if self._behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS_POS: + return { + "input": {1: "sequence_length"}, + } + + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + if self._behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS: + return {"last_hidden_state": {0: "seq_len"}} + if self._behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS_MERGER: + return {"last_hidden_state": {0: "seq_len"}, "deepstack_feature_lists": {0: "seq_len"}} + if self._behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS_POS: + return {"last_hidden_state": {0: "seq_len", 1: "seq_len"}} + return {} + + +@register_in_tasks_manager( + "qwen3_vl_moe", + *["image-text-to-text", "video-text-to-text"], + library_name="transformers", +) +class 
Qwen3_VL_MOEOpenVINOConfig(Qwen3_VLOpenVINOConfig): + def with_behavior( + self, + behavior: Union[str, Qwen3VLConfigBehavior], + ): + """ + Creates a config for different behaviour. + Args: + behavior ([`ConfigBehavior`]): + The behavior to use for the new instance. + """ + if isinstance(behavior, str) and not isinstance(behavior, Qwen3VLConfigBehavior): + behavior = Qwen3VLConfigBehavior(behavior) + + if behavior == Qwen3VLConfigBehavior.TEXT_EMBEDDINGS: + return get_vlm_text_embeddings_config("qwen3_vl_moe_text", self._orig_config.text_config, self.int_dtype, self.float_dtype) + + if behavior == Qwen3VLConfigBehavior.LANGUAGE: + return get_vlm_text_generation_config( + "qwen3_vl_moe_text", + self._orig_config.text_config, + self.int_dtype, + self.float_dtype, + model_patcher=Qwen3VLLanguageModelPatcher, + dummy_input_generator=DummyQwen2VLLMInputGenerator, + inputs_update={"position_ids": {1: "batch_size", 2: "sequence_length"}}, + ) + + if behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS: + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + if behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS_MERGER: + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + if behavior == Qwen3VLConfigBehavior.VISION_EMBEDDINGS_POS: + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) @register_in_tasks_manager( diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 851308e29e..74f1540c05 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -4337,6 +4337,42 @@ def __exit__(self, exc_type, exc_value, traceback): super().__exit__(exc_type, exc_value, traceback) self._model.forward = self._model.__orig_forward + +class Qwen3VLLanguageModelPatcher(OVDecoderModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + + # Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L2156-L2178 + # moved audio and vision features processing outside model + def lm_forward(self, attention_mask, position_ids, past_key_values, inputs_embeds, visual_pos_masks, deepstack_visual_embeds, use_cache=True): + from transformers.cache_utils import DynamicCache + + pkv = DynamicCache.from_legacy_cache(past_key_values) + outputs = self.model.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=use_cache, + past_key_values=pkv, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + ) + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states) + return (logits, outputs.past_key_values.to_legacy_cache()) + + model.__orig_forward = model.forward + model.forward = types.MethodType(lm_forward, model) + super().__init__(config, model, model_kwargs) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, 
exc_value, traceback) + self._model.forward = self._model.__orig_forward def patch_qwen2vl_vision_blocks(model, force_new_behaviour=False): if not force_new_behaviour and is_transformers_version("<=", "4.48.99"): @@ -4550,6 +4586,46 @@ def __exit__(self, exc_type, exc_value, traceback): for block in self._model.blocks: block.forward = block._orig_forward block.attn.forward = block.attn._orig_forward + +class Qwen3VLVisionEmbMergerPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Dict[str, Any] = None, + ): + model.__orig_forward = model.forward + + # Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118 + # added attention_mask input instead cu_lens for its internal calculation model (unsupported by tracing due to cycle with dynamic len) + # separated patch_embed and rot_pos_emb calls for performing as part of another model + def image_embed_forward( + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor + ) -> torch.Tensor: + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states = blk(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( + hidden_states + ) + deepstack_feature_lists.append(deepstack_feature) + last_hidden_state = self.merger(hidden_states) + return last_hidden_state, torch.stack(deepstack_feature_lists, dim=0) + + model.forward = types.MethodType(image_embed_forward, model) + super().__init__(config, model, model_kwargs) + + def __enter__(self): + patch_qwen2vl_vision_blocks(self._model) + super().__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.forward = self._model.__orig_forward + for block in self._model.blocks: + block.forward = block._orig_forward + block.attn.forward = block.attn._orig_forward # copied from https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/granitemoe/modeling_granitemoe.py#L321 diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py index d1318fc109..5d04ad3585 100644 --- a/optimum/exporters/openvino/utils.py +++ b/optimum/exporters/openvino/utils.py @@ -228,6 +228,8 @@ def get_submodels(model): "phi3_v", "qwen2_vl", "qwen2_5_vl", + "qwen3_vl", + "qwen3_vl_moe", "got_ocr2", "gemma3", "idefics3", diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py index b88d7097a7..f6f2072fa3 100644 --- a/optimum/intel/openvino/modeling_visual_language.py +++ b/optimum/intel/openvino/modeling_visual_language.py @@ -144,6 +144,9 @@ def prepare_inputs( position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, + visual_pos_masks: Optional[torch.FloatTensor] = None, + deepstack_visual_embeds: Optional[torch.FloatTensor] = None, + **kwargs, ): batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0] @@ -186,11 +189,24 @@ def prepare_inputs( if past_len: position_ids = position_ids[:, -inputs_embeds.shape[1] :] - if self.config.model_type == "qwen2_vl" and position_ids.ndim != 3: + if (self.config.model_type == "qwen2_vl" or 
self.config.model_type == "qwen3_vl" or self.config.model_type == "qwen3_vl_moe") and position_ids.ndim != 3: position_ids = np.repeat(np.expand_dims(position_ids, 0), 3, axis=0) inputs["position_ids"] = position_ids + if "visual_pos_masks" in self.input_names: + if visual_pos_masks is not None: + inputs["visual_pos_masks"] = visual_pos_masks + else: + inputs["visual_pos_masks"] = torch.zeros(1, 1, dtype=torch.bool) + + if "deepstack_visual_embeds" in self.input_names: + num_layers = len(self.config.vision_config.deepstack_visual_indexes) + emd_dim = self.config.text_config.hidden_size + if isinstance(deepstack_visual_embeds, list): + inputs["deepstack_visual_embeds"] = torch.Tensor(deepstack_visual_embeds) + else: + inputs["deepstack_visual_embeds"] = torch.zeros((num_layers, 1, emd_dim), dtype=torch.float32) if "token_type_ids" in self.input_names: if token_type_ids is None: token_type_ids = np.zeros(inputs_embeds.shape[:2], dtype=int) @@ -200,7 +216,6 @@ def prepare_inputs( inputs["beam_idx"] = ( self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int) ) - return inputs def forward( @@ -210,16 +225,19 @@ def forward( past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, + visual_pos_masks: Optional[torch.FloatTensor] = None, + deepstack_visual_embeds: Optional[torch.FloatTensor] = None, **kwargs, ): self.compile() - inputs = self.prepare_inputs( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, position_ids=position_ids, inputs_embeds=inputs_embeds, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, **kwargs, ) # Run inference @@ -332,6 +350,7 @@ def forward(self, audio_feature, audio_mask): "vision_resampler": OVVisionResampler, "multi_modal_projector": OVMultiModalProjector, "vision_embeddings_merger": OVVisionEmbedding, + "vision_embeddings_pos": OVVisionProjection, "audio_embeddings": OVAudioEmbeddings, "audio_forward_embeddings": OVAudioEmbeddings, "audio_encoder": OVAudioEncoder, @@ -767,38 +786,75 @@ def forward( ): if pixel_values is None: pixel_values = images if images is not None else image_pixel_values - inputs_embeds, attention_mask, position_ids = self.get_multimodal_embeddings( - input_ids, - pixel_values, - image_sizes=image_sizes, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - image_bound=image_bound, - tgt_sizes=tgt_sizes, - pixel_values_videos=pixel_values_videos, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - rope_deltas=rope_deltas, - second_per_grid_ts=second_per_grid_ts, - pixel_attention_mask=pixel_attention_mask, - input_image_embeds=input_image_embeds, - image_attention_mask=image_attention_mask, - input_audio_embeds=input_audio_embeds if input_audio_embeds is not None else audio_input_features, - audio_embed_sizes=audio_embed_sizes, - audio_attention_mask=audio_attention_mask, - input_mode=input_mode, - **kwargs, - ) - return self.language_model.forward( - input_ids=None, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - token_type_ids=token_type_ids, - past_key_values=past_key_values, - **kwargs, - ) + if self.config.model_type == "qwen3_vl" or self.config.model_type == "qwen3_vl_moe": + inputs_embeds, attention_mask, position_ids, visual_pos_masks, deepstack_visual_embeds = self.get_multimodal_embeddings( + input_ids, + 
pixel_values, + inputs_embeds=inputs_embeds, + image_sizes=image_sizes, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + image_bound=image_bound, + tgt_sizes=tgt_sizes, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + second_per_grid_ts=second_per_grid_ts, + pixel_attention_mask=pixel_attention_mask, + input_image_embeds=input_image_embeds, + image_attention_mask=image_attention_mask, + input_audio_embeds=input_audio_embeds if input_audio_embeds is not None else audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + input_mode=input_mode, + **kwargs, + ) + return self.language_model.forward( + input_ids=None, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + token_type_ids=token_type_ids, + past_key_values=past_key_values, + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + **kwargs, + ) + else: + inputs_embeds, attention_mask, position_ids = self.get_multimodal_embeddings( + input_ids, + pixel_values, + image_sizes=image_sizes, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + image_bound=image_bound, + tgt_sizes=tgt_sizes, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + rope_deltas=rope_deltas, + second_per_grid_ts=second_per_grid_ts, + pixel_attention_mask=pixel_attention_mask, + input_image_embeds=input_image_embeds, + image_attention_mask=image_attention_mask, + input_audio_embeds=input_audio_embeds if input_audio_embeds is not None else audio_input_features, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + input_mode=input_mode, + **kwargs, + ) + return self.language_model.forward( + input_ids=None, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + token_type_ids=token_type_ids, + past_key_values=past_key_values, + **kwargs, + ) def _reorder_cache(self, past_key_values, beam_idx): return self.language_model._reorder_cache(past_key_values, beam_idx) @@ -923,6 +979,13 @@ def preprocess_inputs( Preprocess input instruction and an image. """ + # modified from https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/generation/utils.py#L1992 + def _prepare_cache_for_generation(self, *args, **kwargs): + """ + This function is used to prepare the cache : when calling `generate` before the first inference, an instance of `DynamicCache` will be created. + For OVModel, we don't want model_kwargs to be updated before generation. 
+ """ + return class _OVLlavaForCausalLM(OVModelForVisualCausalLM): def __init__( @@ -2487,7 +2550,6 @@ class QWen2VLModelOutputWithPast(ModelOutput): rope_deltas: Optional[torch.FloatTensor] = None second_per_grid_ts: Optional[torch.FloatTensor] = None - class _OVQwen2VLForCausalLM(OVModelForVisualCausalLM): additional_parts = ["vision_embeddings_merger"] @@ -3353,7 +3415,57 @@ def preprocess_inputs( inputs = processor(images=image, text=text_prompt, videos=video, return_tensors="pt") return inputs - # Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1602 + return model_kwargs + +class _OVQwen3VLForCausalLM(OVModelForVisualCausalLM): + additional_parts = ["vision_embeddings_merger", "vision_embeddings_pos"] + + def __init__( + self, + language_model: ov.Model, + text_embeddings: ov.Model, + vision_embeddings: ov.Model, + config: PretrainedConfig = None, + device: str = "CPU", + dynamic_shapes: bool = None, + ov_config: Optional[Dict[str, str]] = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + quantization_config: Union[OVWeightQuantizationConfig, Dict] = None, + **kwargs, + ): + super().__init__( + language_model=language_model, + text_embeddings=text_embeddings, + vision_embeddings=vision_embeddings, + config=config, + device=device, + dynamic_shapes=dynamic_shapes, + ov_config=ov_config, + model_save_dir=model_save_dir, + quantization_config=quantization_config, + **kwargs, + ) + self.rope_deltas = None # cache rope_deltas here + + if is_transformers_version(">=", "4.56.0"): + from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLVisionRotaryEmbedding as VisionRotaryEmbedding, + ) + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionRotaryEmbedding + + self._rotary_pos_emb = VisionRotaryEmbedding( + self.config.vision_config.hidden_size // self.config.vision_config.num_heads // 2 + ) + self.num_grid_per_side = int(config.vision_config.num_position_embeddings**0.5) + self.spatial_merge_size = config.vision_config.spatial_merge_size + head_dim = config.vision_config.hidden_size // config.vision_config.num_heads + self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2) + + else: + raise ValueError( + f"Initialization of {self.config.model_type} requires at least transformers >= 4.56" + ) + def _update_model_kwargs_for_generation( self, outputs: ModelOutput, @@ -3374,6 +3486,545 @@ return model_kwargs + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + if past_key_values is not None: + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif inputs_embeds is not None: + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None
and len(cache_position) == inputs_embeds.shape[1]: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "cache_position": cache_position, + } + ) + return model_inputs + + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids.""" + + # Since we use timestamps to seperate videos, like , the video_grid_thw should also be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + # t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal 
information for videos) + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + max_hw = int(grid_thw[:, 1:].max().item()) + freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in 
zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list) + weight_tensor = torch.tensor(weight_list) + pos_embeds = torch.from_numpy(self.vision_embeddings_pos(idx_tensor)) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.config.vision_config.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: Optional[torch.FloatTensor] = None, + video_features: Optional[torch.FloatTensor] = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel(): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + + + def get_vision_embeddings(self, pixel_values, grid_thw, **kwargs): + hidden_states = torch.from_numpy(self.vision_embeddings(pixel_values)[0]) + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32 + ) + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + attention_mask = torch.zeros((1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool) + causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + + causal_mask.masked_fill_(torch.logical_not(attention_mask), float("-inf")) + + res = self.vision_embeddings_merger( + pixel_values=hidden_states, attention_mask=causal_mask, rotary_pos_emb=rotary_pos_emb + ) + return res[0], res[1] + + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. 
+ """ + # pixel_values = pixel_values.type(self.visual.dtype) + image_embeds, deepstack_image_embeds = self.get_vision_embeddings(pixel_values, image_grid_thw) + image_embeds, deepstack_image_embeds = torch.from_numpy(image_embeds), torch.from_numpy(deepstack_image_embeds) + deepstack_image_embeds = deepstack_image_embeds.tolist() + split_sizes = (image_grid_thw.prod(-1) // self.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds, deepstack_image_embeds + + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. The deepstack visual features are also returned. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + video_embeds = self.get_vision_embeddings(pixel_values_videos, video_grid_thw) + video_embeds, deepstack_video_embeds = torch.from_numpy(video_embeds[0]), torch.from_numpy(video_embeds[1]) + split_sizes = (video_grid_thw.prod(-1) // self.spatial_merge_size**2).tolist() + video_embeds = torch.split(video_embeds, split_sizes) + return video_embeds, deepstack_video_embeds + + def get_multimodal_embeddings( + self, + input_ids, + pixel_values=None, + attention_mask=None, + position_ids=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + cache_position=None, + **kwargs, + ): + image_mask = None + video_mask = None + inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids)) + + if pixel_values is not None: + image_embeds, deepstack_image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds, deepstack_video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + visual_pos_masks = None + deepstack_visual_embeds = None + if image_mask is not None and video_mask is not None: + # aggregate visual_pos_masks and deepstack_visual_embeds + image_mask = image_mask[..., 0] + video_mask = video_mask[..., 0] + visual_pos_masks = image_mask | video_mask + deepstack_visual_embeds = [] + image_mask_joint = image_mask[visual_pos_masks] + video_mask_joint = video_mask[visual_pos_masks] + for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds): + embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device) + embed_joint[image_mask_joint, :] = img_embed + embed_joint[video_mask_joint, :] = vid_embed + deepstack_visual_embeds.append(embed_joint) + elif image_mask is not None: + image_mask = image_mask[..., 0] + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_image_embeds + elif video_mask is not None: + video_mask = video_mask[..., 0] + 
visual_pos_masks = video_mask + deepstack_visual_embeds = deepstack_video_embeds + + if position_ids is None: + attention_mask_tensor = ( + attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"] + ) + if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4: + attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2) + # Only apply conversion for floating point tensors (inverted masks) + if attention_mask_tensor.dtype.is_floating_point: + attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min + attention_mask_tensor = (1.0 - attention_mask_tensor).int() + + # Calculate RoPE index once per generation in the pre-fill stage only. + # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + if self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask=attention_mask_tensor, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + return inputs_embeds, attention_mask, position_ids, visual_pos_masks, deepstack_visual_embeds + + @staticmethod + def preprocess_inputs( + text: str, + image: Optional["Image"] = None, + processor: Optional[AutoImageProcessor] = None, + tokenizer: Optional[PreTrainedTokenizer] = None, + config: Optional[PretrainedConfig] = None, + video: Optional["VideoInput"] = None, + audio: Optional[np.ndarray] = None, + ): + if processor is None: + raise ValueError("Processor is required.") + if audio is not None: + raise ValueError("Audio input is not supported") + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": text}, + ], + } + ] + if image is not None: + conversation[0]["content"].insert(0, {"type": "image"}) + if video is not None: + conversation[0]["content"].insert(0, {"type": "video"}) + + text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + + inputs = processor(images=image, text=text_prompt, videos=video, return_tensors="pt") + return inputs + + + def forward( + self, + input_ids, + pixel_values=None, + past_key_values=None, + inputs_embeds=None, + image_sizes=None, + attention_mask=None, + position_ids=None, + image_bound=None, + tgt_sizes=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + rope_deltas=None, + **kwargs, + ): + result = super().forward( + input_ids, + pixel_values, + past_key_values, + inputs_embeds, + image_sizes, + attention_mask, + position_ids, + image_bound, + tgt_sizes, + pixel_values_videos, + image_grid_thw, + video_grid_thw, + rope_deltas, + **kwargs, + ) + final_result = QWen2VLModelOutputWithPast( + logits=result.logits, past_key_values=result.past_key_values, rope_deltas=rope_deltas + ) 
+ return final_result + + class _OVMaira2ForCausalLM(_OVLlavaForCausalLM): @staticmethod def preprocess_inputs( @@ -4349,6 +5000,8 @@ def preprocess_inputs( "internvl_chat": _OVInternVLForCausalLM, "qwen2_vl": _OVQwen2VLForCausalLM, "qwen2_5_vl": _OVQwen2_5_VLForCausalLM, + "qwen3_vl": _OVQwen3VLForCausalLM, + "qwen3_vl_moe": _OVQwen3VLForCausalLM, "got_ocr2": _OVGotOCR2ForCausalLM, "gemma3": _OVGemma3ForCausalLM, "idefics3": _OVIdefics3ForCausalLM, diff --git a/setup.py b/setup.py index 4af8f58123..622d276b71 100644 --- a/setup.py +++ b/setup.py @@ -28,8 +28,8 @@ INSTALL_REQUIRE = [ "torch>=1.11", - "optimum==1.27.*", - "transformers>=4.36,<4.54", + "optimum", + "transformers>=4.36", "datasets>=1.4.0", "setuptools", "scipy",
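For reviewers, a minimal end-to-end sketch of how the new `qwen3_vl` / `qwen3_vl_moe` path is expected to be exercised once this lands. This is an untested illustration and not part of the diff: the checkpoint id below is a placeholder, and the call pattern simply mirrors the existing Qwen2-VL / Qwen2.5-VL flow through `OVModelForVisualCausalLM`.

```python
# Untested sketch; "Qwen/Qwen3-VL-4B-Instruct" is a placeholder model id.
import requests
from PIL import Image
from transformers import AutoProcessor

from optimum.intel import OVModelForVisualCausalLM

model_id = "Qwen/Qwen3-VL-4B-Instruct"  # placeholder checkpoint name
processor = AutoProcessor.from_pretrained(model_id)

# export=True converts the checkpoint to OpenVINO IR using the new qwen3_vl export configs and patchers
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# preprocess_inputs builds the chat-template prompt and pixel inputs for the selected model type
inputs = model.preprocess_inputs(text="Describe this image.", image=image, processor=processor)
generated_ids = model.generate(**inputs, max_new_tokens=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```

The conversion step can presumably also be done ahead of time with `optimum-cli export openvino --model <model_id> <output_dir>` and the resulting IR loaded without `export=True`, as with the other supported VLM architectures.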