Commit 0e0cc69

add specific providers

Signed-off-by: Ananth Subramaniam <[email protected]>
1 parent: 4e080b4

File tree: 5 files changed, +267 -3 lines

src/megatron/bridge/models/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -186,6 +186,11 @@
     # Nemotron Models
     "NemotronBridge",
     "NemotronModelProvider",
+    "Nemotron3ModelProvider4B",
+    "Nemotron3ModelProvider8B",
+    "Nemotron3ModelProvider22B",
+    "Nemotron4ModelProvider15B",
+    "Nemotron4ModelProvider340B",
     # VL Models
     "Qwen25VLModel",
     "Qwen25VLBridge",

src/megatron/bridge/models/nemotron/__init__.py

Lines changed: 13 additions & 1 deletion

@@ -13,10 +13,22 @@
 # limitations under the License.

 from megatron.bridge.models.nemotron.nemotron_bridge import NemotronBridge
-from megatron.bridge.models.nemotron.nemotron_provider import NemotronModelProvider
+from megatron.bridge.models.nemotron.nemotron_provider import (
+    Nemotron3ModelProvider4B,
+    Nemotron3ModelProvider8B,
+    Nemotron3ModelProvider22B,
+    Nemotron4ModelProvider15B,
+    Nemotron4ModelProvider340B,
+    NemotronModelProvider,
+)


 __all__ = [
     "NemotronBridge",
     "NemotronModelProvider",
+    "Nemotron3ModelProvider4B",
+    "Nemotron3ModelProvider8B",
+    "Nemotron3ModelProvider22B",
+    "Nemotron4ModelProvider15B",
+    "Nemotron4ModelProvider340B",
 ]

src/megatron/bridge/models/nemotron/nemotron_provider.py

Lines changed: 99 additions & 1 deletion

@@ -14,7 +14,7 @@

 import logging
 from dataclasses import dataclass, field
-from typing import Callable
+from typing import Callable, Optional

 import torch

@@ -50,3 +50,101 @@ class NemotronModelProvider(GPTModelProvider):
     layernorm_zero_centered_gamma: bool = True
     cross_entropy_loss_fusion: bool = True
     apply_rope_fusion: bool = field(default_factory=fusions.can_enable_apply_rope_fusion)
+
+    # Nemotron3Config4B as default configs
+    num_layers: int = 32
+    seq_length: int = 4096
+    hidden_size: int = 3072
+    ffn_hidden_size: int = 9216
+    num_attention_heads: int = 24
+    num_query_groups: Optional[int] = 8
+    kv_channels: Optional[int] = 128
+    init_method_std: float = 0.0134
+
+    # Data type settings to match HF models
+    bf16: bool = True
+    fp16: bool = False
+    params_dtype: torch.dtype = torch.bfloat16
+    autocast_dtype: torch.dtype = torch.bfloat16
+
+
+@dataclass
+class Nemotron3ModelProvider4B(NemotronModelProvider):
+    """
+    Configuration class for the Nemotron3 4B model, inheriting from NemotronModelProvider.
+    Maps to: nvidia/Minitron-4B-Base, nvidia/Nemotron-Mini-4B-Instruct
+    """
+
+    num_layers: int = 32
+    seq_length: int = 4096
+    hidden_size: int = 3072
+    ffn_hidden_size: int = 9216
+    num_attention_heads: int = 24
+    num_query_groups: int = 8
+    kv_channels: Optional[int] = 128
+    init_method_std: float = 0.0134
+
+
+@dataclass
+class Nemotron3ModelProvider8B(NemotronModelProvider):
+    """
+    Configuration class for the Nemotron3 8B model, inheriting from NemotronModelProvider.
+    Maps to: nvidia/Minitron-8B-Base
+    """
+
+    num_layers: int = 32
+    seq_length: int = 4096
+    hidden_size: int = 4096
+    ffn_hidden_size: int = 16384
+    num_attention_heads: int = 48  # Updated to match HF model (was 32)
+    num_query_groups: int = 8  # Updated to match HF model (was None)
+    kv_channels: Optional[int] = 128  # Updated to match HF model (was None)
+    init_method_std: float = 0.010
+
+
+@dataclass
+class Nemotron3ModelProvider22B(NemotronModelProvider):
+    """
+    Configuration class for the Nemotron3 22B model, inheriting from NemotronModelProvider.
+    """
+
+    num_layers: int = 40
+    seq_length: int = 4096
+    hidden_size: int = 6144
+    ffn_hidden_size: int = 24576
+    num_attention_heads: int = 48
+    num_query_groups: Optional[int] = None
+    kv_channels: Optional[int] = None
+    init_method_std: float = 0.008
+
+
+@dataclass
+class Nemotron4ModelProvider15B(NemotronModelProvider):
+    """
+    Configuration class for the Nemotron4 15B model, inheriting from NemotronModelProvider.
+    """
+
+    num_layers: int = 32
+    seq_length: int = 4096
+    hidden_size: int = 6144
+    ffn_hidden_size: int = 24576
+    num_attention_heads: int = 48
+    num_query_groups: Optional[int] = 8
+    kv_channels: Optional[int] = None
+    init_method_std: float = 0.0134
+
+
+@dataclass
+class Nemotron4ModelProvider340B(NemotronModelProvider):
+    """
+    Configuration class for the Nemotron4 340B model, inheriting from NemotronModelProvider.
+    """
+
+    num_layers: int = 96
+    seq_length: int = 4096
+    hidden_size: int = 18432
+    ffn_hidden_size: int = 73728
+    num_attention_heads: int = 96
+    num_query_groups: Optional[int] = 8
+    kv_channels: Optional[int] = None
+    init_method_std: float = 0.0063
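As the "Nemotron3Config4B as default configs" comment indicates, the base NemotronModelProvider now carries the 4B shape as its defaults, and each subclass only overrides size-specific fields. Since these are plain dataclasses, fields can also be overridden per instance. A minimal sketch, assuming the base provider can be instantiated directly with its defaults; the seq_length override is purely illustrative:

```python
from megatron.bridge.models.nemotron.nemotron_provider import (
    Nemotron3ModelProvider4B,
    NemotronModelProvider,
)

# The base provider defaults now match the 4B configuration.
base = NemotronModelProvider()
four_b = Nemotron3ModelProvider4B()
assert base.hidden_size == four_b.hidden_size == 3072

# Dataclass fields can be overridden at construction time
# (hypothetical longer-context variant, for illustration only).
long_ctx = Nemotron3ModelProvider4B(seq_length=8192)
assert long_ctx.seq_length == 8192
```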
Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from megatron.bridge.models.conversion.auto_bridge import AutoBridge
+from megatron.bridge.models.nemotron import (
+    Nemotron3ModelProvider4B,
+    Nemotron3ModelProvider8B,
+)
+from tests.functional_tests.utils import compare_provider_configs
+
+
+HF_MODEL_ID_TO_BRIDGE_MODEL_PROVIDER = {
+    "nvidia/Minitron-4B-Base": Nemotron3ModelProvider4B,
+    "nvidia/Minitron-8B-Base": Nemotron3ModelProvider8B,
+    "nvidia/Nemotron-Mini-4B-Instruct": Nemotron3ModelProvider4B,
+}
+
+
+class TestNemotronModelProviderMapping:
+    """Test that bridge provider configs are equivalent to predefined provider configs."""
+
+    @pytest.mark.parametrize("hf_model_id,provider_class", list(HF_MODEL_ID_TO_BRIDGE_MODEL_PROVIDER.items()))
+    def test_bridge_vs_predefined_provider_config_equivalence(self, hf_model_id, provider_class):
+        """Test that bridge converted provider config matches predefined provider config."""
+        # Create bridge from HF model
+        bridge = AutoBridge.from_hf_pretrained(hf_model_id)
+        converted_provider = bridge.to_megatron_provider(load_weights=False)
+        converted_provider.finalize()
+
+        # Create predefined provider
+        predefined_provider = provider_class()
+        predefined_provider.finalize()
+
+        # Compare configs
+        compare_provider_configs(converted_provider, predefined_provider, hf_model_id)
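The compare_provider_configs helper is imported from tests.functional_tests.utils and is not part of this diff, so its exact behavior is not shown here. A rough, hypothetical sketch of what such a field-by-field comparison could look like, using only fields that appear in the providers above:

```python
# Hypothetical sketch only: the real compare_provider_configs in
# tests.functional_tests.utils is not shown in this commit and may
# compare more (or different) fields.
FIELDS_TO_COMPARE = (
    "num_layers", "hidden_size", "ffn_hidden_size",
    "num_attention_heads", "num_query_groups", "kv_channels",
    "seq_length", "init_method_std",
)

def compare_provider_configs_sketch(converted, predefined, model_id):
    """Assert that two provider configs agree on key architecture fields."""
    for name in FIELDS_TO_COMPARE:
        got, want = getattr(converted, name), getattr(predefined, name)
        assert got == want, f"{model_id}: {name} mismatch ({got} != {want})"
```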

tests/unit_tests/models/nemotron/test_nemotron_bridge.py

Lines changed: 102 additions & 1 deletion

@@ -21,7 +21,14 @@
 from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
 from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
 from megatron.bridge.models.nemotron.nemotron_bridge import NemotronBridge
-from megatron.bridge.models.nemotron.nemotron_provider import NemotronModelProvider
+from megatron.bridge.models.nemotron.nemotron_provider import (
+    Nemotron3ModelProvider4B,
+    Nemotron3ModelProvider8B,
+    Nemotron3ModelProvider22B,
+    Nemotron4ModelProvider15B,
+    Nemotron4ModelProvider340B,
+    NemotronModelProvider,
+)


 class TestNemotronBridge:

@@ -132,3 +139,97 @@ def test_dtype_configuration(self, mock_pretrained_nemotron):
         assert provider.bf16 is True
         assert provider.fp16 is False
         assert provider.params_dtype == torch.bfloat16
+
+
+class TestNemotronSpecificProviders:
+    """Test cases for specific Nemotron model provider configurations."""
+
+    def test_nemotron3_4b_config(self):
+        """Test Nemotron3 4B provider configuration matches HF model specs."""
+        provider = Nemotron3ModelProvider4B()
+
+        # Should match nvidia/Minitron-4B-Base and nvidia/Nemotron-Mini-4B-Instruct
+        assert provider.hidden_size == 3072
+        assert provider.num_layers == 32
+        assert provider.num_attention_heads == 24
+        assert provider.num_query_groups == 8
+        assert provider.ffn_hidden_size == 9216
+        assert provider.kv_channels == 128
+        assert provider.seq_length == 4096
+        assert provider.init_method_std == 0.0134
+
+    def test_nemotron3_8b_config(self):
+        """Test Nemotron3 8B provider configuration matches HF model specs."""
+        provider = Nemotron3ModelProvider8B()
+
+        # Should match nvidia/Minitron-8B-Base
+        assert provider.hidden_size == 4096
+        assert provider.num_layers == 32
+        assert provider.num_attention_heads == 48  # Updated to match HF
+        assert provider.num_query_groups == 8  # Updated to match HF
+        assert provider.ffn_hidden_size == 16384
+        assert provider.kv_channels == 128  # Updated to match HF
+        assert provider.seq_length == 4096
+        assert provider.init_method_std == 0.010
+
+    def test_nemotron3_22b_config(self):
+        """Test Nemotron3 22B provider configuration."""
+        provider = Nemotron3ModelProvider22B()
+
+        assert provider.hidden_size == 6144
+        assert provider.num_layers == 40
+        assert provider.num_attention_heads == 48
+        assert provider.num_query_groups is None
+        assert provider.ffn_hidden_size == 24576
+        assert provider.kv_channels is None
+        assert provider.seq_length == 4096
+        assert provider.init_method_std == 0.008
+
+    def test_nemotron4_15b_config(self):
+        """Test Nemotron4 15B provider configuration."""
+        provider = Nemotron4ModelProvider15B()
+
+        assert provider.hidden_size == 6144
+        assert provider.num_layers == 32
+        assert provider.num_attention_heads == 48
+        assert provider.num_query_groups == 8
+        assert provider.ffn_hidden_size == 24576
+        assert provider.kv_channels is None
+        assert provider.seq_length == 4096
+        assert provider.init_method_std == 0.0134
+
+    def test_nemotron4_340b_config(self):
+        """Test Nemotron4 340B provider configuration."""
+        provider = Nemotron4ModelProvider340B()
+
+        # Should match nvidia/Nemotron-4-340B-Base/Instruct (if available)
+        assert provider.hidden_size == 18432
+        assert provider.num_layers == 96
+        assert provider.num_attention_heads == 96
+        assert provider.num_query_groups == 8
+        assert provider.ffn_hidden_size == 73728
+        assert provider.kv_channels is None
+        assert provider.seq_length == 4096
+        assert provider.init_method_std == 0.0063
+
+    def test_all_providers_have_nemotron_defaults(self):
+        """Test that all specific providers inherit Nemotron-specific defaults."""
+        providers = [
+            Nemotron3ModelProvider4B(),
+            Nemotron3ModelProvider8B(),
+            Nemotron3ModelProvider22B(),
+            Nemotron4ModelProvider15B(),
+            Nemotron4ModelProvider340B(),
+        ]
+
+        for provider in providers:
+            # Check Nemotron-specific defaults
+            assert provider.normalization == "LayerNorm"
+            assert provider.position_embedding_type == "rope"
+            assert provider.share_embeddings_and_output_weights is False
+            assert provider.add_bias_linear is False
+            assert provider.hidden_dropout == 0.0
+            assert provider.attention_dropout == 0.0
+            assert provider.rotary_percent == 0.5
+            assert provider.layernorm_zero_centered_gamma is True
+            assert provider.cross_entropy_loss_fusion is True