
Commit 4f78a83

Add Gemma2 support (#32)
Co-authored-by: Ilyas Moutawwakil <[email protected]>
Co-authored-by: IlyasMoutawwakil <[email protected]>
1 parent 2ba7502

File tree

6 files changed: +26 -1 lines changed


docs/source/onnx/overview.mdx (2 additions, 0 deletions)

@@ -52,6 +52,8 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
 - ESM
 - Falcon
 - Flaubert
+- Gemma
+- Gemma 2
 - GLM
 - GPT-2
 - GPT-BigCode

optimum/exporters/onnx/model_configs.py (10 additions, 0 deletions)

@@ -512,6 +512,16 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
     MIN_TRANSFORMERS_VERSION = version.parse("4.38.0")
 
 
+@register_tasks_manager_onnx("gemma2", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"])
+class Gemma2OnnxConfig(TextDecoderOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
+    # Gemma 2 was added in transformers v4.42 using HybridCache
+    # (tuple of past_key_values never supported), DynamicCache since v4.53
+    MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")
+
+
 @register_tasks_manager_onnx("gpt_oss", *COMMON_TEXT_GENERATION_TASKS)
 class GPTOssOnnxConfig(GemmaOnnxConfig):
     MIN_TRANSFORMERS_VERSION = version.parse("4.55.0")

optimum/onnxruntime/modeling_decoder.py (3 additions, 1 deletion)

@@ -747,6 +747,7 @@ def _from_pretrained(
 
         # Important: for encoder-decoder models used with CausalLM, we need to set the is_decoder flag to True
         # and the is_encoder_decoder flag to False. This is needed for the model to work correctly with generation logic.
+        config.use_cache = use_cache
         if hasattr(config, "is_decoder"):
             config.is_decoder = True
         if hasattr(config, "is_encoder_decoder"):
@@ -770,7 +771,8 @@ def _from_pretrained(
             generation_config = GenerationConfig.from_model_config(config)
 
         generation_config.use_cache = use_cache
-        config.use_cache = use_cache
+        if hasattr(generation_config, "cache_implementation"):
+            generation_config.cache_implementation = None
 
         if is_transformers_version(">=", "4.45.0"):
            misplaced_generation_parameters = config._get_non_default_generation_parameters()
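Two things happen above: config.use_cache = use_cache moves earlier, so it is already set when the generation config is derived from config, and a new guard clears generation_config.cache_implementation. Gemma 2 checkpoints ship a generation config requesting the "hybrid" cache implementation; left in place, generate() would try to build a transformers HybridCache, which the exported graph cannot consume since it only exposes plain past_key_values inputs and outputs. A standalone sketch of the guarded behavior, with the "hybrid" value taken from Gemma 2's published generation config:

from transformers import GenerationConfig

# As loaded for a Gemma 2 checkpoint: the config asks transformers to build a
# HybridCache, which an ONNX Runtime session has no way to consume.
generation_config = GenerationConfig(cache_implementation="hybrid")

# The same guard as in the diff: clear the flag so generation falls back to the
# exported graph's explicit past_key_values inputs/outputs.
if hasattr(generation_config, "cache_implementation"):
    generation_config.cache_implementation = None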

tests/exporters/onnx/utils_tests.py (1 addition, 0 deletions)

@@ -103,6 +103,7 @@
     },
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
+    "gemma2": "hf-internal-testing/tiny-random-Gemma2ForCausalLM",
     "glm": "hf-internal-testing/tiny-random-GlmForCausalLM",
     "glpn": "hf-internal-testing/tiny-random-GLPNModel",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",

tests/onnxruntime/test_decoder.py (9 additions, 0 deletions)

@@ -32,6 +32,7 @@
     BloomOnnxConfig,
     CohereOnnxConfig,
     DeepSeekV3OnnxConfig,
+    Gemma2OnnxConfig,
     GemmaOnnxConfig,
     GLMOnnxConfig,
     GPTOssOnnxConfig,
@@ -113,6 +114,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
         SUPPORTED_ARCHITECTURES.append("qwen2")
     if is_transformers_version(">=", str(GemmaOnnxConfig.MIN_TRANSFORMERS_VERSION)):
         SUPPORTED_ARCHITECTURES.append("gemma")
+    if is_transformers_version(">=", str(Gemma2OnnxConfig.MIN_TRANSFORMERS_VERSION)):
+        SUPPORTED_ARCHITECTURES.append("gemma2")
     if is_transformers_version(">=", str(GLMOnnxConfig.MIN_TRANSFORMERS_VERSION)):
         SUPPORTED_ARCHITECTURES.append("glm")
     if is_transformers_version(">=", str(MPTOnnxConfig.MIN_TRANSFORMERS_VERSION)):
@@ -220,6 +223,12 @@ def test_find_untested_architectures(self):
             # So we remove it from the list of supported architectures in the versions before 4.48.0.
             supported_architectures.remove("nemotron")
 
+        if "gemma2" in supported_architectures and is_transformers_version(
+            "<", str(Gemma2OnnxConfig.MIN_TRANSFORMERS_VERSION)
+        ):
+            # Gemma 2 was added in transformers v4.42 using HybridCache (tuple of past_key_values never supported), DynamicCache since v4.53
+            supported_architectures.remove("gemma2")
+
         untested_architectures = supported_architectures - tested_architectures
 
         if len(untested_architectures) > 0:
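The gate reuses the export-side constraint: the same Gemma2OnnxConfig.MIN_TRANSFORMERS_VERSION decides both when "gemma2" joins SUPPORTED_ARCHITECTURES and when it is excused from the untested-architectures check. A small sketch of the version helper's behavior, assuming it is importable from optimum.utils as in this test module:

from optimum.utils import is_transformers_version

# True only when the installed transformers has the DynamicCache-based Gemma 2
# path, i.e. the version named by Gemma2OnnxConfig.MIN_TRANSFORMERS_VERSION.
if is_transformers_version(">=", "4.53.0"):
    print("gemma2 would be appended to SUPPORTED_ARCHITECTURES here")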

tests/onnxruntime/testing_utils.py (1 addition, 0 deletions)

@@ -68,6 +68,7 @@
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "flux": "optimum-internal-testing/tiny-random-flux",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
+    "gemma2": "hf-internal-testing/tiny-random-Gemma2ForCausalLM",
     "glm": "hf-internal-testing/tiny-random-GlmForCausalLM",
     "gpt2": "hf-internal-testing/tiny-random-GPT2LMHeadModel",
     "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
