Skip to content

Commit 68f201a

Browse files
Add Nemotron Support (#36)
Add Nemotron Support
1 parent 5655b4d commit 68f201a

File tree

7 files changed

+24
-1
lines changed

7 files changed

+24
-1
lines changed

docs/source/onnx/overview.mdx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
8686
- MPNet
8787
- MT5
8888
- Musicgen (text-conditional only)
89+
- Nemotron
8990
- Nystromformer
9091
- OLMo
9192
- OLMo2

optimum/exporters/onnx/model_configs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,12 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
507507
MIN_TRANSFORMERS_VERSION = version.parse("4.38.0")
508508

509509

@register_tasks_manager_onnx("nemotron", *COMMON_TEXT_GENERATION_TASKS)
class NemotronOnnxConfig(GemmaOnnxConfig):
    """ONNX export configuration for Nemotron text-generation models.

    Inherits the export behavior from ``GemmaOnnxConfig`` and only overrides
    the minimum supported transformers version and the normalized-config
    class (Nemotron uses grouped-query attention, hence the GQA variant).
    """

    # 4.44.0 introduced Nemotron in transformers, but 4.48.0 is more stable
    # (earlier releases had cache-handling issues — see the commit message).
    MIN_TRANSFORMERS_VERSION = version.parse("4.48.0")
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA
514+
515+
510516
@register_tasks_manager_onnx("granite", *COMMON_TEXT_GENERATION_TASKS)
511517
class GraniteOnnxConfig(LlamaOnnxConfig):
512518
MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")

optimum/exporters/onnx/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
"internlm2",
8686
"llama",
8787
"mistral",
88+
"nemotron",
8889
"phi",
8990
"phi3",
9091
"qwen2",

optimum/onnxruntime/modeling_decoder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def __init__(
204204
"To re-export your model, simply set `export=True` as in `from_pretrained(..., export=True, use_cache=True)`."
205205
)
206206

207-
if self.config.model_type == "gemma":
207+
if self.config.model_type in {"gemma", "nemotron"}:
208208
self.embed_size_per_head = self.config.head_dim
209209
elif self.config.model_type == "gpt_bigcode":
210210
self.embed_size_per_head = self.config.hidden_size // self.config.num_attention_heads * 2
@@ -223,6 +223,7 @@ def __init__(
223223
"helium",
224224
"mistral",
225225
"llama",
226+
"nemotron",
226227
"qwen2",
227228
"qwen3",
228229
"qwen3_moe",

tests/exporters/onnx/utils_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@
142142
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
143143
"mt5": "lewtun/tiny-random-mt5",
144144
"musicgen": "hf-internal-testing/tiny-random-MusicgenForConditionalGeneration",
145+
"nemotron": "badaoui/tiny-random-NemotronForCausalLM",
145146
"nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
146147
"olmo": "hf-internal-testing/tiny-random-OlmoForCausalLM",
147148
"olmo2": "hf-internal-testing/tiny-random-Olmo2ForCausalLM",

tests/onnxruntime/test_decoder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
HeliumOnnxConfig,
3838
InternLM2OnnxConfig,
3939
MPTOnnxConfig,
40+
NemotronOnnxConfig,
4041
Olmo2OnnxConfig,
4142
OlmoOnnxConfig,
4243
OPTOnnxConfig,
@@ -109,6 +110,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
109110
SUPPORTED_ARCHITECTURES.append("gemma")
110111
if is_transformers_version(">=", str(MPTOnnxConfig.MIN_TRANSFORMERS_VERSION)):
111112
SUPPORTED_ARCHITECTURES.append("mpt")
113+
if is_transformers_version(">=", str(NemotronOnnxConfig.MIN_TRANSFORMERS_VERSION)):
114+
SUPPORTED_ARCHITECTURES.append("nemotron")
112115
if is_transformers_version(">=", str(GraniteOnnxConfig.MIN_TRANSFORMERS_VERSION)):
113116
SUPPORTED_ARCHITECTURES.append("granite")
114117
if is_transformers_version(">=", str(HeliumOnnxConfig.MIN_TRANSFORMERS_VERSION)):
@@ -199,6 +202,15 @@ def test_find_untested_architectures(self):
199202
transformers_architectures = set(CONFIG_MAPPING_NAMES.keys())
200203
onnx_architectures = set(TasksManager.get_supported_model_type_for_task(task=self.TASK, exporter="onnx"))
201204
supported_architectures = onnx_architectures & transformers_architectures
205+
206+
if "nemotron" in supported_architectures and is_transformers_version(
207+
"<=", str(NemotronOnnxConfig.MIN_TRANSFORMERS_VERSION)
208+
):
209+
# Nemotron was introduced in Transformers 4.44.0, but it has some issues. Specifically, it did not properly handle legacy cache formats (Lists/Cache), and it also did not return past_key_values when use_cache=True.
210+
# We are using its 4.48.0 version, which is more stable.
211+
# So we remove it from the list of supported architectures in the versions before 4.48.0.
212+
supported_architectures.remove("nemotron")
213+
202214
untested_architectures = supported_architectures - tested_architectures
203215

204216
if len(untested_architectures) > 0:

tests/onnxruntime/testing_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
"mpnet": "hf-internal-testing/tiny-random-MPNetModel",
100100
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
101101
"mt5": "lewtun/tiny-random-mt5",
102+
"nemotron": "badaoui/tiny-random-NemotronForCausalLM",
102103
"nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
103104
"olmo": "katuni4ka/tiny-random-olmo-hf",
104105
"olmo2": "hf-internal-testing/tiny-random-Olmo2ForCausalLM",

0 commit comments

Comments (0)