diff --git a/docs/source/onnx/overview.mdx b/docs/source/onnx/overview.mdx
index 01f6ed1c..01c68bdf 100644
--- a/docs/source/onnx/overview.mdx
+++ b/docs/source/onnx/overview.mdx
@@ -52,6 +52,8 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
 - ESM
 - Falcon
 - Flaubert
+- Gemma
+- Gemma3
 - GPT-2
 - GPT-BigCode
 - GPT-J
diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py
index 9c4a660e..420781a7 100644
--- a/optimum/exporters/onnx/base.py
+++ b/optimum/exporters/onnx/base.py
@@ -124,6 +124,8 @@ class OnnxConfig(ExporterConfig, ABC):
         "image-to-image": OrderedDict(
             {"reconstruction": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}
         ),
+        "image-text-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
+        "image-text-to-text-with-past": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
         "keypoint-detection": OrderedDict(
             {"heatmaps": {0: "batch_size", 1: "num_keypoints", 2: "height", 3: "width"}}
         ),
diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py
index 951131b2..8a958f5f 100644
--- a/optimum/exporters/onnx/config.py
+++ b/optimum/exporters/onnx/config.py
@@ -15,13 +15,16 @@
 from __future__ import annotations
 
+import enum
 from collections import OrderedDict
 from collections.abc import Iterable
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, ClassVar, Self
 
 from optimum.exporters.onnx.base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
 from optimum.exporters.onnx.constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
+from optimum.exporters.onnx.model_patcher import ModelPatcher
+from optimum.exporters.tasks import TasksManager
 from optimum.utils import (
     DummyAudioInputGenerator,
     DummyBboxInputGenerator,
@@ -465,3 +468,258 @@ def post_process_exported_models(
         models_and_onnx_configs[ONNX_DECODER_WITH_PAST_NAME][1]._decoder_onnx_config.is_merged = True
 
     return models_and_onnx_configs, onnx_files_subpaths
+
+
+class VLMConfigBehavior(str, enum.Enum):
+    """Specifies the behavior of the [`~exporters.onnx.config.VLMDecoderOnnxConfig`].
+
+    - MONOLITH: the config can be used to export the entire multimodal model as a single file.
+    - VISION_ENCODER: the config can be used to export the underlying vision encoder.
+    - MULTIMODAL_PROJECTOR: the config can be used to export the underlying multimodal projector.
+    - TEXT_ENCODER: the config can be used to export the underlying text encoder, mapping input ids to embeddings.
+    - LANGUAGE_MODEL: the config can be used to export the underlying language model. Note: this does not
+      include the language model head.
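+    - LANGUAGE_MODEL_WITH_HEAD: the config can be used to export the underlying language model together with its
+      language modeling head.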
+ """ + + MONOLITH = "monolith" + VISION_ENCODER = "vision_encoder" + MULTIMODAL_PROJECTOR = "multimodal_projector" + TEXT_ENCODER = "text_encoder" + LANGUAGE_MODEL = "language_model" + LANGUAGE_MODEL_WITH_HEAD = "language_model_with_head" + + +class VLMDecoderOnnxConfig(TextDecoderOnnxConfig): + """Base config for decoder-based vision language models.""" + + DUMMY_INPUT_GENERATOR_CLASSES = TextAndVisionOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES + SUPPORTED_BEHAVIORS: ClassVar[list[VLMConfigBehavior]] = list(VLMConfigBehavior) + + def __init__( + self, + config: PretrainedConfig, + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + use_past: bool = False, + use_past_in_inputs: bool = False, + preprocessors: list[Any] | None = None, + legacy: bool = False, + behavior: VLMConfigBehavior = VLMConfigBehavior.MONOLITH, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + use_past=use_past, + use_past_in_inputs=use_past_in_inputs, + preprocessors=preprocessors, + legacy=legacy, + ) + self._behavior = behavior + + @property + def behavior(self) -> VLMConfigBehavior: + """The behavior property.""" + return self._behavior + + @behavior.setter + def behavior(self, value: str | VLMConfigBehavior) -> None: + if isinstance(value, str): + try: + value = VLMConfigBehavior(value) + except ValueError: + raise ValueError( + f"behavior must be one of {self.SUPPORTED_BEHAVIORS}, but got {value} instead." + ) from None + + self._behavior = value + + def get_supported_behaviors(self, task: str) -> list[VLMConfigBehavior]: + """Get supported behaviors for this model. + + The supported behaviors are task-dependent. For instance, "text-generation" is handled by + the language model and associated head. + """ + if "image-text-to-text" in task: + # All parts of the model + return [ + VLMConfigBehavior.VISION_ENCODER, + VLMConfigBehavior.MULTIMODAL_PROJECTOR, + VLMConfigBehavior.TEXT_ENCODER, + VLMConfigBehavior.LANGUAGE_MODEL_WITH_HEAD, + ] + + elif "text-generation" in task: + # Only text-related components needed + return [ + VLMConfigBehavior.TEXT_ENCODER, + VLMConfigBehavior.LANGUAGE_MODEL_WITH_HEAD, + ] + + elif "feature-extraction" in task: + # Same as image-text-to-text but without the LM head + return [ + VLMConfigBehavior.VISION_ENCODER, + VLMConfigBehavior.MULTIMODAL_PROJECTOR, + VLMConfigBehavior.TEXT_ENCODER, + VLMConfigBehavior.LANGUAGE_MODEL, + ] + + else: + message = f"Invalid task for {self.__class__.__name__}: {task}" + raise ValueError(message) + + def with_behavior(self, behavior: VLMConfigBehavior) -> Self: + if behavior == VLMConfigBehavior.LANGUAGE_MODEL_WITH_HEAD: + model_config = self._config.text_config + model_type = model_config.model_type + + if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: + raise ValueError( + f"Unsupported language model type provided `{model_type}`. 
Please define custom export config" + ) + + lm_task = "text-generation-with-past" if self.use_past else "text-generation" + exporter_config_constructor = TasksManager.get_exporter_config_constructor( + exporter="onnx", + model_type=model_type, + task=lm_task, + ) + return exporter_config_constructor( + model_config, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + use_past=self.use_past, + use_past_in_inputs=self.use_past_in_inputs, + ) + + elif behavior in [ + VLMConfigBehavior.MONOLITH, + VLMConfigBehavior.TEXT_ENCODER, + VLMConfigBehavior.VISION_ENCODER, + VLMConfigBehavior.MULTIMODAL_PROJECTOR, + VLMConfigBehavior.LANGUAGE_MODEL, + ]: + # TODO: check if we need to handle vision encoder part similarly, with config.vision_config + return type(self)( + config=self._config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + use_past=self.use_past, + use_past_in_inputs=self.use_past_in_inputs, + preprocessors=self._preprocessors, + legacy=self.legacy, + behavior=behavior, + ) + + message = f"Behavior must be one of {self.SUPPORTED_BEHAVIORS}, but got {behavior} instead." + raise ValueError(message) + + def get_model_for_behavior(self, model: PreTrainedModel, behavior: VLMConfigBehavior): + if behavior != self.behavior: + raise ValueError( + f"Config behavior {self.behavior} does not match the requested behavior {behavior}. Please run `.with_behavior` first." + ) + + if behavior == VLMConfigBehavior.LANGUAGE_MODEL: + return model.language_model + + if behavior == VLMConfigBehavior.LANGUAGE_MODEL_WITH_HEAD: + # No default way to get just the LM and LM head, so we get entire model. + return model + + if behavior == VLMConfigBehavior.VISION_ENCODER: + vision_encoder = model.vision_tower + vision_encoder.config = model.config.vision_config + return vision_encoder + + if behavior == VLMConfigBehavior.MULTIMODAL_PROJECTOR: + multi_modal_projector = model.multi_modal_projector + # TODO: check if multimodal projector actually acceps the base config, not config.vision_config + multi_modal_projector.config = model.config + return multi_modal_projector + + if behavior == VLMConfigBehavior.MONOLITH: + return model + + if behavior == VLMConfigBehavior.TEXT_ENCODER: + return model.get_input_embeddings() + + message = f"Behavior must be one of {self.SUPPORTED_BEHAVIORS}, but got {behavior} instead." + raise ValueError(message) + + @property + def inputs(self) -> dict[str, dict[int, str]]: + if self.behavior == VLMConfigBehavior.VISION_ENCODER: + return {"pixel_values": {0: "batch_size"}} + + if self.behavior == VLMConfigBehavior.MULTIMODAL_PROJECTOR: + # Should be batch_size, number of tokens per image, and hidden size of the vision encoder + return { + "vision_outputs": { + 0: "batch_size", + 1: "num_patch_tokens", + 2: "hidden_size", + } + } + + if self.behavior in ( + VLMConfigBehavior.LANGUAGE_MODEL, + VLMConfigBehavior.LANGUAGE_MODEL_WITH_HEAD, + ): + return super().inputs + + if self.behavior == VLMConfigBehavior.MONOLITH: + inputs = super().inputs + + # text-generation task should not include images. + if "image-text-to-text" in self.task: + # No need to add channel and image dimensions + inputs["pixel_values"] = {0: "batch_size"} + + return inputs + + if self.behavior == VLMConfigBehavior.TEXT_ENCODER: + return super().inputs + + message = f"Behavior must be one of {self.SUPPORTED_BEHAVIORS}, but got {self.behavior} instead." 
+ raise ValueError(message) + + @property + def outputs(self) -> dict[str, dict[int, str]]: + if self.behavior in ( + VLMConfigBehavior.VISION_ENCODER, + VLMConfigBehavior.LANGUAGE_MODEL, + ): + return {"last_hidden_state": {0: "batch_size"}} + + if self.behavior == VLMConfigBehavior.MULTIMODAL_PROJECTOR: + return { + "image_features": { + 0: "batch_size", + 1: "mm_tokens_per_image", + 2: "text_hidden_size", + } + } + + if self.behavior == VLMConfigBehavior.TEXT_ENCODER: + return {"inputs_embeds": {0: "batch_size", 1: "sequence_length"}} + + return super().outputs + + def patch_model_for_export( + self, model: PreTrainedModel, model_kwargs: dict[str, Any] | None = None + ) -> ModelPatcher: + if self.behavior in ( + VLMConfigBehavior.LANGUAGE_MODEL, + VLMConfigBehavior.LANGUAGE_MODEL_WITH_HEAD, + VLMConfigBehavior.MONOLITH, + ): + if model_kwargs is None: + model_kwargs = {} + model_kwargs["use_cache"] = self.use_past + + return super().patch_model_for_export(model=model, model_kwargs=model_kwargs) diff --git a/optimum/exporters/onnx/constants.py b/optimum/exporters/onnx/constants.py index a996bac5..ab0729b6 100644 --- a/optimum/exporters/onnx/constants.py +++ b/optimum/exporters/onnx/constants.py @@ -38,3 +38,5 @@ "musicgen", "whisper", ] + +VLM_TEXT_GENERATION_MODELS = ["gemma3"] diff --git a/optimum/exporters/onnx/input_generators.py b/optimum/exporters/onnx/input_generators.py new file mode 100644 index 00000000..88872e4b --- /dev/null +++ b/optimum/exporters/onnx/input_generators.py @@ -0,0 +1,194 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Custom input generators for ONNX export configs.""" + +from typing import Optional, cast + +import torch + +from optimum.utils import ( + DEFAULT_DUMMY_SHAPES, + DummyTextInputGenerator, + NormalizedTextConfig, +) + + +class Gemma3DummyInputGenerator(DummyTextInputGenerator): + """Dummy input generator for Gemma3.""" + + SUPPORTED_INPUT_NAMES = ( + "input_ids", + "attention_mask", + "pixel_values", + "vision_outputs", + ) + + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + num_choices: int = DEFAULT_DUMMY_SHAPES["num_choices"], + random_batch_size_range: Optional[tuple[int, int]] = None, + random_sequence_length_range: Optional[tuple[int, int]] = None, + random_num_choices_range: Optional[tuple[int, int]] = None, + padding_side: str = "right", + **kwargs, + ): + super().__init__( + task, + normalized_config, + batch_size, + sequence_length, + num_choices, + random_batch_size_range, + random_sequence_length_range, + random_num_choices_range, + padding_side, + **kwargs, + ) + + # Gemma3 default image size + self.height = self.width = 896 + self.n_channels = 3 + self.padding = "left" + self.image_token_index = int(self.normalized_config.image_token_index) + self.mm_tokens_per_image = int(self.normalized_config.mm_tokens_per_image) + self.boi_token_index = self.normalized_config.boi_token_index + self.eoi_token_index = self.normalized_config.eoi_token_index + self.image_size = self.normalized_config.vision_config.image_size + self.patch_size = self.normalized_config.vision_config.patch_size + + def _generate_pixel_values( + self, + framework: str, + float_dtype: str, + ): + """Generate random pixel values.""" + shape = [self.batch_size, self.n_channels, self.height, self.width] + min_value = -1 + # See Gemma3ImageProcessor's `rescale_factor`. 
+ max_value = 1 / 255 + return self.random_float_tensor( + shape=shape, + min_value=min_value, + max_value=max_value, + framework=framework, + dtype=float_dtype, + ) + + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ): + if input_name == "pixel_values": + # Vision encoder input + return self._generate_pixel_values(framework, float_dtype) + + elif input_name == "vision_outputs": + # Multimodal projector input + patches_per_image = int(self.image_size // self.patch_size) ** 2 + shape = [self.batch_size, patches_per_image, self.sequence_length] + return self.random_float_tensor(shape=shape, framework=framework, dtype=float_dtype) + + elif input_name == "image_features": + shape = [ + self.batch_size, + self.mm_tokens_per_image, + self.normalized_config.text_config.hidden_size, + ] + return self.random_float_tensor( + shape=shape, + framework=framework, + dtype=float_dtype, + ) + + generated_inputs = super().generate(input_name, framework, int_dtype, float_dtype) + image_size = (self.batch_size, self.mm_tokens_per_image) + if input_name == "input_ids": + # Add image tokens corresponding to mm_tokens_per_image's per image + if framework == "pt": + input_ids = cast(torch.Tensor, generated_inputs) + image_tokens = torch.full( + size=image_size, + fill_value=self.image_token_index, + dtype=input_ids.dtype, + ) + return torch.cat((image_tokens, input_ids), dim=1) + elif framework == "tf": + import tensorflow as tf + + input_ids = cast(tf.Tensor, generated_inputs) + + image_tokens = tf.fill( + dims=image_size, + value=self.image_token_index, + ) + return tf.concat((image_tokens, generated_inputs), axis=1) + elif framework == "np": + import numpy as np + + input_ids = cast(np.ndarray, generated_inputs) + image_tokens = np.full( + shape=image_size, + fill_value=self.image_token_index, + dtype=input_ids.dtype, + ) + return np.concatenate((image_tokens, generated_inputs), axis=1) + + elif input_name == "attention_mask": + # Add attention mask for image tokens + if framework == "pt": + attention_mask = cast(torch.Tensor, generated_inputs) + image_attention_mask = torch.ones( + size=image_size, + dtype=attention_mask.dtype, + ) + if self.padding == "right": + return torch.cat((attention_mask, image_attention_mask), dim=1) + else: + return torch.cat((image_attention_mask, attention_mask), dim=1) + elif framework == "tf": + import tensorflow as tf + + attention_mask = cast(tf.Tensor, generated_inputs) + + image_attention_mask = tf.ones( + image_size, + dtype=attention_mask.dtype, + ) + if self.padding == "right": + return tf.concat((attention_mask, image_attention_mask), axis=1) + else: + return tf.concat((image_attention_mask, attention_mask), axis=1) + elif framework == "np": + import numpy as np + + attention_mask = cast(np.ndarray, generated_inputs) + + image_attention_mask = np.ones( + image_size, + dtype=generated_inputs.dtype, + ) + if self.padding == "right": + return np.concatenate((attention_mask, image_attention_mask), axis=1) + else: + return np.concatenate((image_attention_mask, attention_mask), axis=1) + + else: + raise ValueError(f"Input name {input_name} not supported.") diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 43a95881..660fefc9 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -33,13 +33,18 @@ TextEncoderOnnxConfig, TextSeq2SeqOnnxConfig, VisionOnnxConfig, + VLMConfigBehavior, + VLMDecoderOnnxConfig, ) 
from optimum.exporters.onnx.constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME +from optimum.exporters.onnx.input_generators import Gemma3DummyInputGenerator from optimum.exporters.onnx.model_patcher import ( CLIPModelPatcher, CohereModelPatcher, FluxTransformerModelPatcher, + Gemma3LMModelPatcher, MgpstrModelPatcher, + ModelPatcher, MusicgenModelPatcher, Qwen3MoeModelPatcher, SAMModelPatcher, @@ -142,6 +147,27 @@ "text2text-generation-with-past", ] +COMMON_VLM_TEXT_GENERATION_TASKS = [ + *COMMON_TEXT_GENERATION_TASKS, + "image-text-to-text", + "image-text-to-text-with-past", +] + + +def init_model_configs(): + """Initialize custom model configs on the task manager.""" + # Hacky but works, would be cleaner to expose this behavior in the TasksManager + TasksManager._CUSTOM_CLASSES[("pt", "gemma3", "image-text-to-text")] = ( + "transformers", + "Gemma3ForConditionalGeneration", + ) + TasksManager._CUSTOM_CLASSES[("pt", "gemma3", "image-text-to-text-with-past")] = ( + "transformers", + "Gemma3ForConditionalGeneration", + ) + + +init_model_configs() register_tasks_manager_onnx = TasksManager.create_register("onnx") @@ -507,6 +533,52 @@ class GemmaOnnxConfig(LlamaOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.38.0") +@register_tasks_manager_onnx("gemma3_text", *[*COMMON_TEXT_GENERATION_TASKS]) +class Gemma3TextDecoderOnnxConfig(GemmaOnnxConfig): + MIN_TRANSFORMERS_VERSION = version.parse("4.52.0") + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="num_hidden_layers") + _MODEL_PATCHER = Gemma3LMModelPatcher + + def get_model_for_behavior(self, model: PreTrainedModel, behavior: VLMConfigBehavior): + # Unused + _ = behavior + return model.language_model + + @property + def outputs(self) -> dict[str, dict[int, str]]: + output = {"last_hidden_state": {0: "batch_size"}} + if self.use_past: + self.add_past_key_values(output, "outputs") + + return output + + def patch_model_for_export( + self, model: PreTrainedModel, model_kwargs: dict[str, Any] | None = None + ) -> ModelPatcher: + model_kwargs = model_kwargs or {} + model_kwargs["use_cache"] = self.use_past + + return super().patch_model_for_export(model, model_kwargs=model_kwargs) + + +@register_tasks_manager_onnx("gemma3", *COMMON_VLM_TEXT_GENERATION_TASKS) +class Gemma3OnnxConfig(VLMDecoderOnnxConfig): + MIN_TRANSFORMERS_VERSION = version.parse("4.52.0") + DUMMY_INPUT_GENERATOR_CLASSES = ( + Gemma3DummyInputGenerator, + DummyVisionInputGenerator, + GemmaDummyPastKeyValuesGenerator, + ) + DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextAndVisionConfig.with_args( + text_config="text_config", + vision_config="vision_config", + head_dim="text_config.head_dim", + num_key_value_heads="text_config.num_key_value_heads", + allow_new=True, + ) + + @register_tasks_manager_onnx("nemotron", *COMMON_TEXT_GENERATION_TASKS) class NemotronOnnxConfig(GemmaOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.48.0") # More stable version than 4.44.0 diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 83bcfb2f..9ff32919 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -18,7 +18,7 @@ import inspect import sys import types -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union import torch import transformers @@ -30,6 +30,7 @@ jit_utils, symbolic_helper, ) +from transformers 
import PreTrainedModel, TFPreTrainedModel from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet from optimum.exporters.onnx._traceable_cache import TraceableCache @@ -1237,6 +1238,158 @@ def __exit__(self, exc_type, exc_value, traceback): Qwen3MoeSparseMoeBlock.forward = self.original_moe_forward +# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1147 +def _gemma3_mm_update_causal_mask( + self, + attention_mask, + token_type_ids, + past_key_values, + cache_position, + input_tensor, + is_training: bool = False, +): + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted + # form and requires no inversion or slicing. + return attention_mask + + min_dtype = torch.finfo(torch.float16).min + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else cache_position[0] + sequence_length + 1 + ) + + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=self.dtype, + device=cache_position.device, + ) + + # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + + # Apply bidirectional mask on images if token type ids are provided + if token_type_ids is not None and sequence_length != 1: + token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2) + token_type_mask[token_type_ids == 0] = False # if text token do not change anything + token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool) + causal_mask = causal_mask.clone() + causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill( + token_type_mask, 0.0 + ) + + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + + # Then apply padding mask (will mask pad tokens) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype) + + return causal_mask + + +class Gemma3LMModelPatcher(DecoderModelPatcher): + def __init__( + self, + config: OnnxConfig, + model: Union[PreTrainedModel, TFPreTrainedModel], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + # Difference from original: + # uses Dynamic cache from legacy cache instead of HybridCache + # calculate causal mask from multimodal + + def forward( + self, + attention_mask, + position_ids, + past_key_values, + token_type_ids, + inputs_embeds, + use_cache=True, + ): + from transformers.cache_utils import DynamicCache + + pkv = DynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = past_key_values[0][0].shape[-2] + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + forward_kwargs = {} + + if is_transformers_version("<", "4.52"): + attention_mask = 
self._update_causal_mask_mm(
+                    attention_mask,
+                    token_type_ids,
+                    past_key_values,
+                    cache_position,
+                    inputs_embeds,
+                )
+            else:
+                forward_kwargs["token_type_ids"] = token_type_ids
+
+            result = self.__orig_forward(
+                input_ids=None,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                cache_position=cache_position,
+                past_key_values=pkv,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+                **forward_kwargs,
+            )
+            upd_pkv = result["past_key_values"]
+            result["past_key_values"] = upd_pkv.to_legacy_cache()
+            return result
+
+        if is_transformers_version("<", "4.53.0"):
+            model.__orig_forward = model.forward
+            model.forward = types.MethodType(forward, model)
+
+        super().__init__(config, model, model_kwargs)
+
+    def __enter__(self):
+        super().__enter__()
+
+        if is_transformers_version("<", "4.52.0"):
+            self._model._update_causal_mask_mm = types.MethodType(_gemma3_mm_update_causal_mask, self._model)
+        elif (
+            is_transformers_version("<", "4.53.0")
+            and hasattr(self._model, "model")
+            and hasattr(self._model.model, "_update_causal_mask")
+        ):
+            self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
+            self._model.model._update_causal_mask = types.MethodType(_gemma3_mm_update_causal_mask, self._model.model)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+
+        if is_transformers_version("<", "4.53.0"):
+            self._model.forward = self._model.__orig_forward
+
+        if is_transformers_version("<", "4.52"):
+            del self._model._update_causal_mask_mm
+        elif (
+            is_transformers_version("<", "4.53.0")
+            and hasattr(self._model, "model")
+            and hasattr(self._model.model, "_orig_update_causal_mask")
+        ):
+            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+            del self._model.model._orig_update_causal_mask
+
+
 # A patched version of diffusers.models.transformers.transformer_flux.apply_rotary_emb
 # that doesn't reshape the input tensor `x` (which results in a constant shape in the exported ONNX graph)
 def patched_apply_rotary_emb(
@@ -1344,3 +1497,15 @@ def __exit__(self, exc_type, exc_value, traceback):
         from transformers.models.cohere.modeling_cohere import CohereRotaryEmbedding
 
         CohereRotaryEmbedding.forward = self.original_forward
+
+
+def patch_update_causal_mask(
+    model, transformers_version, inner_model_name="model", patch_fn=None, patch_external_model=False
+):
+    if is_transformers_version(">=", transformers_version):
+        inner_model = getattr(model, inner_model_name, None) if not patch_external_model else model
+        if inner_model is not None:
+            if hasattr(inner_model, "_update_causal_mask"):
+                inner_model._orig_update_causal_mask = inner_model._update_causal_mask
+            patch_fn = patch_fn or _update_causal_mask_patched
+            inner_model._update_causal_mask = types.MethodType(patch_fn, inner_model)
diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py
index 876ad84e..08cd8c56 100644
--- a/optimum/exporters/onnx/utils.py
+++ b/optimum/exporters/onnx/utils.py
@@ -21,6 +21,8 @@
 from packaging import version
 from transformers.utils import is_torch_available
 
+from optimum.exporters.onnx.constants import VLM_TEXT_GENERATION_MODELS
+from optimum.exporters.tasks import TasksManager
 from optimum.exporters.utils import (
     _get_submodels_and_export_configs,
 )
@@ -75,6 +77,7 @@
     "cohere",
     "falcon",
     "gemma",
+    "gemma3",
     "gpt2",
     "gpt_bigcode",
     "gpt_neo",
@@ -211,6 +214,23 @@ def _get_submodels_and_onnx_configs(
     legacy: bool = False,
     model_kwargs: dict | None = None,
 ):
+    if (
+        not
custom_architecture + and library_name == "transformers" + and model.config.model_type in VLM_TEXT_GENERATION_MODELS + and not monolith + ): + return _get_vlm_submodels_and_onnx_configs( + model=model, + task=task, + custom_onnx_configs=custom_onnx_configs, + library_name=library_name, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + model_kwargs=model_kwargs, + ) + return _get_submodels_and_export_configs( model, task, @@ -232,6 +252,50 @@ def _get_submodels_and_onnx_configs( DEPRECATION_WARNING_GET_MODEL_FOR_EXPORT = "The usage of `optimum.exporters.onnx.utils.get_{model_type}_models_for_export` is deprecated and will be removed in a future release, please use `optimum.exporters.utils.get_{model_type}_models_for_export` instead." +def _get_vlm_submodels_and_onnx_configs( + model: PreTrainedModel, + task: str, + custom_onnx_configs: dict, + library_name: str, + int_dtype: str = "int64", + float_dtype: str = "fp32", + preprocessors: list[Any] | None = None, + model_kwargs: dict | None = None, +) -> tuple[ExporterConfig, dict[str, tuple]]: + submodels_and_configs = {} + + main_config_cls = TasksManager.get_exporter_config_constructor( + model=model, task=task, exporter="onnx", library_name=library_name + ) + main_config = main_config_cls( + model.config, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + ) + if not hasattr(main_config, "get_supported_behaviors"): + message = ( + f"VLM '{model.config.model_type}' does not have a `get_supported_behaviors` " + "method configured in its ONNX config class. Please configure and try again." + ) + raise ValueError(message) + + for behavior in main_config.get_supported_behaviors(task): + submodel_config = main_config.with_behavior(behavior) + submodel = submodel_config.get_model_for_behavior(model, behavior) + submodels_and_configs[behavior] = (submodel, submodel_config) + + # Override config if custom config is provided + for key, custom_onnx_config in custom_onnx_configs.items(): + if key not in submodels_and_configs: + message = f"Invalid custom config key '{key}'. Please use one of {', '.join(submodels_and_configs)}." + raise ValueError(message) + submodel = submodels_and_configs[key][0] + submodels_and_configs[key] = (submodel, custom_onnx_config) + + return main_config, submodels_and_configs + + def get_diffusion_models_for_export( pipeline: DiffusionPipeline, int_dtype: str = "int64", diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 268a421c..177107f6 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -204,7 +204,7 @@ def __init__( "To re-export your model, simply set `export=True` as in `from_pretrained(..., export=True, use_cache=True)`." ) - if self.config.model_type in {"gemma", "nemotron"}: + if self.config.model_type in {"gemma", "nemotron", "gemma3"}: self.embed_size_per_head = self.config.head_dim elif self.config.model_type == "gpt_bigcode": self.embed_size_per_head = self.config.hidden_size // self.config.num_attention_heads * 2 @@ -220,6 +220,7 @@ def __init__( "deepseek_v3", "cohere", "gemma", + "gemma3", "helium", "mistral", "llama", diff --git a/optimum/onnxruntime/modeling_vlm.py b/optimum/onnxruntime/modeling_vlm.py new file mode 100644 index 00000000..027fe5a4 --- /dev/null +++ b/optimum/onnxruntime/modeling_vlm.py @@ -0,0 +1,188 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Optional, Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.models import GenerationConfig + +from optimum.onnxruntime.base import ORTSessionMixin +from optimum.onnxruntime.modeling_decoder import ORTModelForCausalLM + +from optimum.onnxruntime.modeling_ort import ORTModel + + +logger = logging.getLogger(__name__) + + +# TODO: to be implemented +class ORTVisionEncoder(ORTSessionMixin): + pass + + +class ORTMultiModalProjector(ORTSessionMixin): + pass + + +class ORTModelForVisualCausalLM(ORTModel): + def __init__( + self, + language_model_with_head, + text_embeddings, + vision_embeddings, + multimodal_projector, + config: PretrainedConfig, + device: str = "CPU", + dynamic_shapes: bool | None = None, + model_save_dir: Union[str, Path, TemporaryDirectory] | None = None, + **kwargs, + ): + if dynamic_shapes is not None: + logger.warning( + f"`dynamic_shapes` was set to {dynamic_shapes}, but this value will be ignored as only dynamic shapes are supported." + ) + + self.is_dynamic = True + self.config = config + self.use_cache = kwargs.get("use_cache", True) + self._model_save_dir = model_save_dir + self._device = device.upper() + + self.preprocessors = kwargs.get("preprocessors", []) + self.vision_embeddings_model = vision_embeddings + self.multimodal_projector = multimodal_projector + self.text_embeddings_model = text_embeddings + self.lm_model = language_model_with_head + + self.language_model = ORTModelForCausalLM( + config=config, + device=device, + ov_config=ov_config, + model_save_dir=model_save_dir, + quantization_config=quantization_config, + compile=self._compile_only or enable_compilation, + compile_only=self._compile_only, + ) + self.vision_embeddings = OVVisionEmbedding(self.vision_embeddings_model, self) + + self.main_input_name = "input_ids" + self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config)) + for part in self.additional_parts: + model_part = getattr(self, f"{part}_model", None) + if model_part is not None: + model_part = MODEL_PARTS_CLS_MAPPING[part](model_part, self) + setattr(self, part, model_part) + + if enable_compilation and not self._compile_only: + self.compile() + + # Avoid warnings when creating a transformers pipeline + AutoConfig.register(self.base_model_prefix, AutoConfig) + self.auto_model_class.register(AutoConfig, self.__class__) + + +class _ORTGemma3ForCausalLM(ORTModelForVisualCausalLM): + def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs): + if input_ids is not None and input_ids.shape[1] == 1: + return None + return self.vision_embeddings(pixel_values).last_hidden_state + + def merge_vision_text_embeddings( + self, + vision_embeds, + inputs_embeds, + input_ids=None, + attention_mask=None, + position_ids=None, + **kwargs, + ): + # Adopted from 
https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1323-L1339 + image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds + inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds + if input_ids is None: + special_image_mask = inputs_embeds == torch.from_numpy( + self.get_text_embeddings(torch.tensor([[self.config.image_token_index]], dtype=torch.long))[0] + ) + else: + special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds) + + image_features = image_features.to(inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + return inputs_embeds, attention_mask, position_ids + + @staticmethod + def preprocess_inputs( + text: str, + image: Optional["Image"] = None, + processor: Optional[AutoImageProcessor] = None, + tokenizer: Optional[PreTrainedTokenizer] = None, + config: Optional[PretrainedConfig] = None, + video: Optional["VideoInput"] = None, + audio: Optional[np.ndarray] = None, + ): + if processor is None: + raise ValueError("Processor is required.") + if video is not None: + raise ValueError("Video input is not supported") + if audio is not None: + raise ValueError("Audio input is not supported") + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": text}, + ], + } + ] + if image is not None: + conversation[0]["content"].insert(0, {"type": "image"}) + + text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) + + inputs = processor(images=image, text=text_prompt, videos=video, return_tensors="pt") + return inputs + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + num_new_tokens=num_new_tokens, + ) + + # Token type ids used only for first inference mask generation + model_kwargs.pop("token_type_ids", None) + + return model_kwargs + + +MODEL_PARTS_CLS_MAPPING = { + "vision_encoder": ORTVisionEncoder, + "multimodal_projector": ORTMultiModalProjector, + "language_model_with_head": ORTModelForCausalLM, + "vision_language_model": ORTModelForVisualCausalLM, +} +MODEL_TYPE_TO_CLS_MAPPING = { + "gemma3": _ORTGemma3ForCausalLM, +} diff --git a/tests/exporters/onnx/test_export.py b/tests/exporters/onnx/test_export.py index bdb2bbcf..8f6406cd 100644 --- a/tests/exporters/onnx/test_export.py +++ b/tests/exporters/onnx/test_export.py @@ -31,7 +31,13 @@ is_torch_available, ) from transformers.modeling_utils import PreTrainedModel -from transformers.testing_utils import require_onnx, require_torch, require_torch_gpu, require_vision, slow +from transformers.testing_utils import ( + require_onnx, + require_torch, + require_torch_gpu, + require_vision, + slow, +) from optimum.exporters import TasksManager from optimum.exporters.error_utils import AtolError @@ -49,10 +55,15 @@ ) from optimum.exporters.onnx.base import ConfigBehavior from optimum.exporters.onnx.config import TextDecoderOnnxConfig +from optimum.exporters.onnx.constants import VLM_TEXT_GENERATION_MODELS from optimum.exporters.onnx.model_configs import WhisperOnnxConfig from 
optimum.exporters.onnx.utils import get_speecht5_models_for_export
 from optimum.onnxruntime.utils import ONNX_DECODER_NAME
-from optimum.utils import DummyPastKeyValuesGenerator, DummyTextInputGenerator, NormalizedTextConfig
+from optimum.utils import (
+    DummyPastKeyValuesGenerator,
+    DummyTextInputGenerator,
+    NormalizedTextConfig,
+)
 from optimum.utils.normalized_config import NormalizedConfigManager
 from optimum.utils.testing_utils import grid_parameters, require_diffusers
@@ -134,6 +145,7 @@ def _get_models_to_test(export_models_dict: dict, library_name: str = "transform
                 "text2text-generation",
                 "automatic-speech-recognition",
                 "image-to-text",
+                "image-text-to-text",
             ]
         ):
             models_to_test.append(
@@ -704,3 +716,72 @@ def test_onnx_seq2seq_model_with_config_with_loss(self):
             ort_sess.run(output_names, input_feed)
 
         gc.collect()
+
+
+class VLMSubmodelExportTestCase(TestCase):
+    """Test that VLM submodels are correctly exported."""
+
+    @parameterized.expand(
+        _get_models_to_test(
+            {name: model for name, model in PYTORCH_EXPORT_MODELS_TINY.items() if name in VLM_TEXT_GENERATION_MODELS}
+        )
+    )
+    @require_torch
+    def test_correct_submodels_per_task(
+        self,
+        test_name: str,
+        model_type: str,
+        model_name: str,
+        task: str,
+        onnx_config_class_constructor: OnnxConfig,
+        monolith: bool,
+    ) -> None:
+        # Not used
+        _ = test_name
+        _ = onnx_config_class_constructor
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            main_export(
+                model_name_or_path=model_name,
+                output=tmp_dir,
+                task=task,
+                monolith=monolith,
+            )
+            self._validate_submodels_in_directory(dir=tmp_dir, task=task, monolith=monolith, model_type=model_type)
+            self._check_models_in_directory(dir=tmp_dir)
+
+    def _validate_submodels_in_directory(self, dir: str, task: str, model_type: str, monolith: bool) -> None:
+        """Validate that the expected submodels are found in the export-target directory."""
+        if monolith:
+            expected_models = {"model.onnx"}
+        elif task in {"text-generation", "text-generation-with-past"}:
+            expected_models = {"text_encoder.onnx", "language_model_with_head.onnx"}
+        elif task in {"image-text-to-text", "image-text-to-text-with-past"}:
+            expected_models = {
+                "vision_encoder.onnx",
+                "multimodal_projector.onnx",
+                "text_encoder.onnx",
+                "language_model_with_head.onnx",
+            }
+        elif task in {"feature-extraction", "feature-extraction-with-past"}:
+            expected_models = {
+                "vision_encoder.onnx",
+                "multimodal_projector.onnx",
+                "text_encoder.onnx",
+                "language_model.onnx",
+            }
+        else:
+            self.fail(f"Task {task} not supported in this test.")
+
+        found_models = {f.name for f in Path(dir).glob("*.onnx")}
+        self.assertEqual(
+            found_models,
+            expected_models,
+            f"Unexpected submodels found for task {task} and model type {model_type}.",
+        )
+
+    def _check_models_in_directory(self, dir: str) -> None:
+        """Check that the exported models can be loaded and pass the ONNX checker."""
+        for model_path in Path(dir).glob("*.onnx"):
+            model = onnx.load(model_path)
+            onnx.checker.check_model(model)
diff --git a/tests/exporters/onnx/utils_tests.py b/tests/exporters/onnx/utils_tests.py
index 803f872d..d4eca425 100644
--- a/tests/exporters/onnx/utils_tests.py
+++ b/tests/exporters/onnx/utils_tests.py
@@ -103,6 +103,7 @@
     },
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
+    "gemma3": "hf-internal-testing/tiny-random-Gemma3ForConditionalGeneration",
     "glpn": "hf-internal-testing/tiny-random-GLPNModel",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
     "gpt_bigcode":
"hf-internal-testing/tiny-random-GPTBigCodeModel", @@ -262,6 +263,7 @@ "encoder-decoder": "patrickvonplaten/bert2bert_cnn_daily_mail", "flaubert": "flaubert/flaubert_small_cased", "gemma": "google/gemma-2b", + "gemma3": "google/gemma-3-4b-it-qat-q4_0-unquantized", "gpt2": "gpt2", "gpt_neo": "EleutherAI/gpt-neo-125M", "gpt_neox": "EleutherAI/gpt-neox-20b", diff --git a/tests/onnxruntime/test_decoder.py b/tests/onnxruntime/test_decoder.py index 85cf4db9..55ea41a3 100644 --- a/tests/onnxruntime/test_decoder.py +++ b/tests/onnxruntime/test_decoder.py @@ -32,6 +32,7 @@ BloomOnnxConfig, CohereOnnxConfig, DeepSeekV3OnnxConfig, + Gemma3OnnxConfig, GemmaOnnxConfig, GraniteOnnxConfig, HeliumOnnxConfig, @@ -108,6 +109,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): SUPPORTED_ARCHITECTURES.append("qwen2") if is_transformers_version(">=", str(GemmaOnnxConfig.MIN_TRANSFORMERS_VERSION)): SUPPORTED_ARCHITECTURES.append("gemma") + if is_transformers_version(">=", str(Gemma3OnnxConfig.MIN_TRANSFORMERS_VERSION)): + SUPPORTED_ARCHITECTURES.append("gemma3") if is_transformers_version(">=", str(MPTOnnxConfig.MIN_TRANSFORMERS_VERSION)): SUPPORTED_ARCHITECTURES.append("mpt") if is_transformers_version(">=", str(NemotronOnnxConfig.MIN_TRANSFORMERS_VERSION)): diff --git a/tests/onnxruntime/test_vlm.py b/tests/onnxruntime/test_vlm.py new file mode 100644 index 00000000..94c12e00 --- /dev/null +++ b/tests/onnxruntime/test_vlm.py @@ -0,0 +1,137 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +import gc + +import requests +import torch +from parameterized import parameterized +from PIL import Image +from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, set_seed +from transformers.configuration_utils import PretrainedConfig +from transformers.models import GenerationConfig + +from optimum.exporters.onnx.model_patcher import patch_update_causal_mask +from optimum.onnxruntime.modeling_vlm import ( + MODEL_PARTS_CLS_MAPPING, + MODEL_TYPE_TO_CLS_MAPPING, + ORTModelForVisualCausalLM, +) +from optimum.utils import is_transformers_version +from tests.onnxruntime.testing_utils import MODEL_NAMES, ORTModelTestMixin + + +SEED = 42 +TEST_IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg" + + +class ORTModelForImageTextToTextIntegrationTest(ORTModelTestMixin): + SUPPORTED_ARCHITECTURES = [] + + if is_transformers_version(">", "4.52.0"): + SUPPORTED_ARCHITECTURES += ["gemma3"] + + TASK = "image-text-to-text" + + IMAGE = Image.open( + requests.get( + TEST_IMAGE_URL, + stream=True, + ).raw + ) + + def get_transformer_model_class(self, model_arch): + if is_transformers_version(">=", "4.52.0") and model_arch in [ + "gemma3", + ]: + from transformers import Gemma3ForConditionalGeneration + + return Gemma3ForConditionalGeneration + + return AutoModelForCausalLM + + def get_preprocessors(self, model_arch: str) -> dict: + model_id = MODEL_NAMES[model_arch] + config = AutoConfig.from_pretrained(model_id) + + return { + "processor": AutoProcessor.from_pretrained(model_id), + "tokenizer": None, + "config": config, + } + + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_compare_to_transformers(self, model_arch: str) -> None: + prompt = "What is shown in this image?" + model_id = MODEL_NAMES[model_arch] + set_seed(SEED) + + transformers_model = self.get_transformer_model_class(model_arch).from_pretrained(model_id).eval() + preprocessors = self.get_preprocessors(model_arch) + ort_model = ORTModelForVisualCausalLM.from_pretrained(model_id, export=True) + self.assertIsInstance(ort_model, MODEL_TYPE_TO_CLS_MAPPING[ort_model.config.model_type]) + for component_name, component in ort_model.components.items(): + self.assertIsInstance(component, MODEL_PARTS_CLS_MAPPING[component_name]) + self.assertIsInstance(ort_model.config, PretrainedConfig) + + inputs = ort_model.preprocess_inputs(**preprocessors, text=prompt, image=self.IMAGE.resize((600, 600))) + transformers_inputs = copy.deepcopy(inputs) + + # Check logits + set_seed(SEED) + ov_outputs = ort_model(**inputs) + set_seed(SEED) + with torch.no_grad(): + transformers_outputs = transformers_model(**transformers_inputs) + self.assertTrue( + torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=4e-3), + f"Max abs diff {(torch.abs(ov_outputs.logits - transformers_outputs.logits).max())}", + ) + + additional_inputs = {} + + # gemma3 does not support dynamic cache, it is unfair to compare dynamic cache result vs hybrid cache, + # align cache representation in torch model + if model_arch == "gemma3": + patch_update_causal_mask( + (transformers_model if is_transformers_version("<", "4.52.0") else transformers_model.language_model), + "4.43.0", + ) + transformers_model._supports_cache_class = True + transformers_model.generation_config.cache_implementation = None + from transformers.cache_utils import DynamicCache + + additional_inputs = {"past_key_values": DynamicCache()} + + # Compare generation + gen_config = GenerationConfig( + max_new_tokens=30, + min_new_tokens=30, + do_sample=False, + 
eos_token_id=None, + ) + set_seed(SEED) + onnx_outputs = ort_model.generate(**inputs, generation_config=gen_config, **additional_inputs) + set_seed(SEED) + with torch.no_grad(): + torch_outputs = transformers_model.generate( + **transformers_inputs, + generation_config=gen_config, + **additional_inputs, + ) + torch.testing.assert_close(onnx_outputs, torch_outputs, atol=self.ATOL, rtol=self.RTOL) + + del ort_model + del transformers_model + gc.collect() diff --git a/tests/onnxruntime/testing_utils.py b/tests/onnxruntime/testing_utils.py index d3c2958b..ff7e3467 100644 --- a/tests/onnxruntime/testing_utils.py +++ b/tests/onnxruntime/testing_utils.py @@ -68,6 +68,7 @@ "flaubert": "hf-internal-testing/tiny-random-flaubert", "flux": "optimum-internal-testing/tiny-random-flux", "gemma": "fxmarty/tiny-random-GemmaForCausalLM", + "gemma3": "hf-internal-testing/tiny-random-Gemma3ForConditionalGeneration", "gpt2": "hf-internal-testing/tiny-random-GPT2LMHeadModel", "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", "gpt_bigcode-multi_query-False": "optimum-internal-testing/tiny-random-gpt_bigcode-multi_query-False",