6 files changed: 26 additions, 1 deletion

@@ -52,6 +52,8 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
 - ESM
 - Falcon
 - Flaubert
+- Gemma
+- Gemma 2
 - GLM
 - GPT-2
 - GPT-BigCode
@@ -512,6 +512,16 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
     MIN_TRANSFORMERS_VERSION = version.parse("4.38.0")
 
 
+@register_tasks_manager_onnx("gemma2", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"])
+class Gemma2OnnxConfig(TextDecoderOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
+    # Gemma 2 was added in transformers v4.42 using HybridCache (the tuple
+    # past_key_values format was never supported); DynamicCache is supported since v4.53.
+    MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")
+
+
 @register_tasks_manager_onnx("gpt_oss", *COMMON_TEXT_GENERATION_TASKS)
 class GPTOssOnnxConfig(GemmaOnnxConfig):
     MIN_TRANSFORMERS_VERSION = version.parse("4.55.0")
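With this config registered, a Gemma 2 checkpoint can be exported and run through ONNX Runtime like any other supported decoder. A minimal sketch, assuming `optimum[onnxruntime]` and `transformers >= 4.53.0` are installed (the tiny checkpoint is the one used by the test suite below; `export=True` triggers the ONNX export on load):

```python
# Sketch: export a Gemma 2 checkpoint to ONNX on the fly and generate with ONNX Runtime.
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM

model_id = "hf-internal-testing/tiny-random-Gemma2ForCausalLM"  # tiny test checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the PyTorch checkpoint using the Gemma2OnnxConfig above;
# use_cache=True exports the variant with past_key_values inputs/outputs.
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```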
@@ -747,6 +747,7 @@ def _from_pretrained(
 
         # Important: for encoder-decoder models used with CausalLM, we need to set the is_decoder flag to True
         # and the is_encoder_decoder flag to False. This is needed for the model to work correctly with generation logic.
+        config.use_cache = use_cache
         if hasattr(config, "is_decoder"):
             config.is_decoder = True
         if hasattr(config, "is_encoder_decoder"):
@@ -770,7 +771,8 @@ def _from_pretrained(
             generation_config = GenerationConfig.from_model_config(config)
 
         generation_config.use_cache = use_cache
-        config.use_cache = use_cache
+        if hasattr(generation_config, "cache_implementation"):
+            generation_config.cache_implementation = None
 
         if is_transformers_version(">=", "4.45.0"):
             misplaced_generation_parameters = config._get_non_default_generation_parameters()
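For context on the second hunk: Gemma 2 checkpoints ship a generation config with `cache_implementation="hybrid"`, which asks transformers' `generate()` to build a `HybridCache` on the PyTorch side. An `ORTModelForCausalLM` feeds `past_key_values` through the exported graph instead, so the hint has to be cleared. A standalone sketch of the behavior (an illustration, not the PR's code):

```python
# Sketch: clearing cache_implementation so generate() falls back to the
# model's own KV-cache handling instead of allocating a HybridCache.
from transformers import GenerationConfig

# Gemma 2's pretrained generation config carries cache_implementation="hybrid".
generation_config = GenerationConfig(cache_implementation="hybrid")

use_cache = True
generation_config.use_cache = use_cache
if hasattr(generation_config, "cache_implementation"):
    # The ONNX graph manages past_key_values itself; a HybridCache would never be used.
    generation_config.cache_implementation = None

assert generation_config.cache_implementation is None
```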
@@ -103,6 +103,7 @@
     },
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
+    "gemma2": "hf-internal-testing/tiny-random-Gemma2ForCausalLM",
     "glm": "hf-internal-testing/tiny-random-GlmForCausalLM",
     "glpn": "hf-internal-testing/tiny-random-GLPNModel",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
@@ -32,6 +32,7 @@
     BloomOnnxConfig,
     CohereOnnxConfig,
     DeepSeekV3OnnxConfig,
+    Gemma2OnnxConfig,
     GemmaOnnxConfig,
     GLMOnnxConfig,
     GPTOssOnnxConfig,
@@ -113,6 +114,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
         SUPPORTED_ARCHITECTURES.append("qwen2")
     if is_transformers_version(">=", str(GemmaOnnxConfig.MIN_TRANSFORMERS_VERSION)):
         SUPPORTED_ARCHITECTURES.append("gemma")
+    if is_transformers_version(">=", str(Gemma2OnnxConfig.MIN_TRANSFORMERS_VERSION)):
+        SUPPORTED_ARCHITECTURES.append("gemma2")
     if is_transformers_version(">=", str(GLMOnnxConfig.MIN_TRANSFORMERS_VERSION)):
         SUPPORTED_ARCHITECTURES.append("glm")
     if is_transformers_version(">=", str(MPTOnnxConfig.MIN_TRANSFORMERS_VERSION)):
@@ -220,6 +223,12 @@ def test_find_untested_architectures(self):
             # So we remove it from the list of supported architectures in the versions before 4.48.0.
             supported_architectures.remove("nemotron")
 
+        if "gemma2" in supported_architectures and is_transformers_version(
+            "<", str(Gemma2OnnxConfig.MIN_TRANSFORMERS_VERSION)
+        ):
+            # Gemma 2 was added in transformers v4.42 using HybridCache (the tuple past_key_values format was never supported); DynamicCache is supported since v4.53.
+            supported_architectures.remove("gemma2")
+
         untested_architectures = supported_architectures - tested_architectures
 
         if len(untested_architectures) > 0:
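The gating above keys off each config's `MIN_TRANSFORMERS_VERSION`. A minimal re-implementation of the pattern, assuming only `packaging` and an installed `transformers` (the `is_transformers_version` below is a stand-in for optimum's helper, not its actual source):

```python
# Sketch of the version-gating pattern used in these tests.
import operator

import transformers
from packaging import version

_OPS = {"<": operator.lt, "<=": operator.le, "==": operator.eq, ">=": operator.ge, ">": operator.gt}


def is_transformers_version(op: str, reference: str) -> bool:
    """Compare the installed transformers version against `reference` using `op`."""
    return _OPS[op](version.parse(transformers.__version__), version.parse(reference))


MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")  # mirrors Gemma2OnnxConfig above

SUPPORTED_ARCHITECTURES = ["gpt2"]
if is_transformers_version(">=", str(MIN_TRANSFORMERS_VERSION)):
    SUPPORTED_ARCHITECTURES.append("gemma2")
```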
@@ -68,6 +68,7 @@
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "flux": "optimum-internal-testing/tiny-random-flux",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
+    "gemma2": "hf-internal-testing/tiny-random-Gemma2ForCausalLM",
     "glm": "hf-internal-testing/tiny-random-GlmForCausalLM",
     "gpt2": "hf-internal-testing/tiny-random-GPT2LMHeadModel",
     "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",