
Commit e34c05c

ananthsub authored and eagle705 committed
Nemotron model provider + bridge (NVIDIA-NeMo#485)
* nemotron model provider
  Signed-off-by: Ananth Subramaniam <[email protected]>
* nemotron bridge
  Signed-off-by: Ananth Subramaniam <[email protected]>
* add specific providers
  Signed-off-by: Ananth Subramaniam <[email protected]>
* update imports and rebase
  Signed-off-by: Ananth Subramaniam <[email protected]>

---------

Signed-off-by: Ananth Subramaniam <[email protected]>
1 parent 85707ce commit e34c05c

File tree

8 files changed: +868 additions, 0 deletions


src/megatron/bridge/models/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -73,6 +73,15 @@
     NemotronNano9Bv2Provider,
     NemotronNano12Bv2Provider,
 )
+from megatron.bridge.models.nemotron import (
+    Nemotron3ModelProvider4B,
+    Nemotron3ModelProvider8B,
+    Nemotron3ModelProvider22B,
+    Nemotron4ModelProvider15B,
+    Nemotron4ModelProvider340B,
+    NemotronBridge,
+    NemotronModelProvider,
+)
 from megatron.bridge.models.qwen import (
     Qwen2ModelProvider,
     Qwen2ModelProvider1P5B,
@@ -184,6 +193,14 @@
     "NVIDIAMambaProvider8B",
     "MistralModelProvider",
     "MistralSmall3ModelProvider24B",
+    # Nemotron Models
+    "NemotronBridge",
+    "NemotronModelProvider",
+    "Nemotron3ModelProvider4B",
+    "Nemotron3ModelProvider8B",
+    "Nemotron3ModelProvider22B",
+    "Nemotron4ModelProvider15B",
+    "Nemotron4ModelProvider340B",
     # VL Models
     "Qwen25VLModel",
     "Qwen25VLBridge",
src/megatron/bridge/models/nemotron/__init__.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from megatron.bridge.models.nemotron.nemotron_bridge import NemotronBridge
+from megatron.bridge.models.nemotron.nemotron_provider import (
+    Nemotron3ModelProvider4B,
+    Nemotron3ModelProvider8B,
+    Nemotron3ModelProvider22B,
+    Nemotron4ModelProvider15B,
+    Nemotron4ModelProvider340B,
+    NemotronModelProvider,
+)
+
+
+__all__ = [
+    "NemotronBridge",
+    "NemotronModelProvider",
+    "Nemotron3ModelProvider4B",
+    "Nemotron3ModelProvider8B",
+    "Nemotron3ModelProvider22B",
+    "Nemotron4ModelProvider15B",
+    "Nemotron4ModelProvider340B",
+]
src/megatron/bridge/models/nemotron/nemotron_bridge.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from megatron.core.models.gpt.gpt_model import GPTModel
+from transformers import NemotronForCausalLM
+
+from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
+from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
+from megatron.bridge.models.conversion.param_mapping import (
+    AutoMapping,
+    QKVMapping,
+)
+from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
+from megatron.bridge.models.nemotron.nemotron_provider import NemotronModelProvider
+
+
+@MegatronModelBridge.register_bridge(source=NemotronForCausalLM, target=GPTModel)
+class NemotronBridge(MegatronModelBridge):
+    """
+    Megatron Bridge for Nemotron Causal LM.
+
+    As a user you would not use this bridge directly, but through `AutoBridge`.
+
+    Example:
+        >>> from megatron.bridge import AutoBridge
+        >>> bridge = AutoBridge.from_hf_pretrained("nvidia/Nemotron-4-340B-Instruct")
+        >>> provider = bridge.to_megatron_provider()
+    """
+
+    def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> NemotronModelProvider:
+        hf_config = hf_pretrained.config
+
+        provider = NemotronModelProvider(
+            num_layers=hf_config.num_hidden_layers,
+            hidden_size=hf_config.hidden_size,
+            ffn_hidden_size=hf_config.intermediate_size,
+            num_attention_heads=hf_config.num_attention_heads,
+            init_method_std=hf_config.initializer_range,
+            layernorm_epsilon=hf_config.norm_eps,
+            num_query_groups=hf_config.num_key_value_heads,
+            seq_length=hf_config.max_position_embeddings,
+            rotary_base=hf_config.rope_theta,
+            rotary_percent=hf_config.partial_rotary_factor,
+            kv_channels=getattr(hf_config, "head_dim", None),
+            make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size),
+            share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False),
+            fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16),
+            bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16),
+            params_dtype=self.dtype_from_hf(hf_config, default=torch.float32),
+            generation_config=hf_pretrained.generation_config,
+            vocab_size=hf_config.vocab_size,
+        )
+
+        return provider
+
+    def mapping_registry(self) -> MegatronMappingRegistry:
+        # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format.
+        # First create simple 1:1 parameter mappings using a dictionary for readability.
+
+        # The dictionary maps Megatron parameter names -> HF parameter names and
+        # supports wildcard (*) patterns for layer-specific parameters.
+        param_mappings = {
+            "embedding.word_embeddings.weight": "model.embed_tokens.weight",
+            "output_layer.weight": "lm_head.weight",
+            "decoder.final_layernorm.weight": "model.norm.weight",
+            "decoder.final_layernorm.bias": "model.norm.bias",
+            "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight",
+            "decoder.layers.*.self_attention.linear_qkv.layer_norm_bias": "model.layers.*.input_layernorm.bias",
+            "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight",
+            "decoder.layers.*.mlp.linear_fc1.layer_norm_bias": "model.layers.*.post_attention_layernorm.bias",
+            "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight",
+            "decoder.layers.*.mlp.linear_fc1.weight": "model.layers.*.mlp.up_proj.weight",
+            "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight",
+        }
+
+        mapping_list = []
+        # Convert each dictionary entry to AutoMapping(megatron_param, hf_param)
+        for megatron_param, hf_param in param_mappings.items():
+            mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param))
+
+        # Add special mappings that require parameter concatenation/transformation
+        mapping_list.extend(
+            [
+                # QKV: combine separate Q, K, V matrices into a single QKV matrix
+                QKVMapping(
+                    megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
+                    q="model.layers.*.self_attn.q_proj.weight",
+                    k="model.layers.*.self_attn.k_proj.weight",
+                    v="model.layers.*.self_attn.v_proj.weight",
+                ),
+            ]
+        )
+
+        return MegatronMappingRegistry(*mapping_list)
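
As the class docstring notes, users reach this bridge through AutoBridge rather than instantiating it directly. A hedged sketch of that flow, expanding the docstring example (assumes megatron.bridge and transformers are installed and the Hugging Face checkpoint is accessible):

from megatron.bridge import AutoBridge

# AutoBridge dispatches to NemotronBridge because register_bridge ties it to
# NemotronForCausalLM (source) and GPTModel (target).
bridge = AutoBridge.from_hf_pretrained("nvidia/Nemotron-4-340B-Instruct")

# provider_bridge() maps HF config fields (num_hidden_layers, intermediate_size,
# rope_theta, partial_rotary_factor, ...) onto a NemotronModelProvider.
provider = bridge.to_megatron_provider()
print(type(provider).__name__, provider.num_layers, provider.rotary_percent)

The mapping_registry() above then drives weight conversion: each AutoMapping handles a 1:1 tensor rename, while the single QKVMapping fuses the separate q_proj/k_proj/v_proj weights into Megatron's packed linear_qkv parameter.
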
src/megatron/bridge/models/nemotron/nemotron_provider.py

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from dataclasses import dataclass, field
+from typing import Callable, Optional
+
+import torch
+
+from megatron.bridge.models.gpt_provider import GPTModelProvider
+from megatron.bridge.utils import fusions
+
+
+logger = logging.getLogger(__name__)
+
+
+def squared_relu(x):
+    """Squared ReLU activation function."""
+    return torch.pow(torch.nn.functional.relu(x), 2)
+
+
+@dataclass
+class NemotronModelProvider(GPTModelProvider):
+    """Configuration class for Nemotron models."""
+
+    # configs that are common across model sizes
+    normalization: str = "LayerNorm"
+    activation_func: Callable = squared_relu
+    position_embedding_type: str = "rope"
+    share_embeddings_and_output_weights: bool = False
+    add_bias_linear: bool = False
+
+    hidden_dropout: float = 0.0
+    attention_dropout: float = 0.0
+    rotary_percent: float = 0.5
+    masked_softmax_fusion: bool = field(default_factory=fusions.can_enable_masked_softmax_fusion)
+    persist_layer_norm: bool = True
+    bias_dropout_add_fusion: bool = False
+    layernorm_zero_centered_gamma: bool = True
+    cross_entropy_loss_fusion: bool = True
+    apply_rope_fusion: bool = field(default_factory=fusions.can_enable_apply_rope_fusion)
+
+    # Nemotron3Config4B as default configs
+    num_layers: int = 32
+    seq_length: int = 4096
+    hidden_size: int = 3072
+    ffn_hidden_size: int = 9216
+    num_attention_heads: int = 24
+    num_query_groups: Optional[int] = 8
+    kv_channels: Optional[int] = 128
+    init_method_std: float = 0.0134
+
+    # Data type settings to match HF models
+    bf16: bool = True
+    fp16: bool = False
+    params_dtype: torch.dtype = torch.bfloat16
+    autocast_dtype: torch.dtype = torch.bfloat16
+
+
+@dataclass
+class Nemotron3ModelProvider4B(NemotronModelProvider):
+    """
+    Configuration class for the Nemotron3 4B model, inheriting from NemotronModelProvider.
+    """
+
+    num_layers: int = 32
+    seq_length: int = 4096
+    hidden_size: int = 3072
+    ffn_hidden_size: int = 9216
+    num_attention_heads: int = 24
+    num_query_groups: int = 8
+    kv_channels: Optional[int] = 128
+    init_method_std: float = 0.0134
+
+
+@dataclass
+class Nemotron3ModelProvider8B(NemotronModelProvider):
+    """
+    Configuration class for the Nemotron3 8B model, inheriting from NemotronModelProvider.
+    """
+
+    num_layers: int = 32
+    seq_length: int = 4096
+    hidden_size: int = 4096
+    ffn_hidden_size: int = 16384
+    num_attention_heads: int = 32
+    num_query_groups: Optional[int] = None
+    kv_channels: Optional[int] = None
+    init_method_std: float = 0.010
+
+
+@dataclass
+class Nemotron3ModelProvider22B(NemotronModelProvider):
+    """
+    Configuration class for the Nemotron3 22B model, inheriting from NemotronModelProvider.
+    """

+    num_layers: int = 40
+    seq_length: int = 4096
+    hidden_size: int = 6144
+    ffn_hidden_size: int = 24576
+    num_attention_heads: int = 48
+    num_query_groups: Optional[int] = None
+    kv_channels: Optional[int] = None
+    init_method_std: float = 0.008
+
+
+@dataclass
+class Nemotron4ModelProvider15B(NemotronModelProvider):
+    """
+    Configuration class for the Nemotron4 15B model, inheriting from NemotronModelProvider.
+    """
+
+    num_layers: int = 32
+    seq_length: int = 4096
+    hidden_size: int = 6144
+    ffn_hidden_size: int = 24576
+    num_attention_heads: int = 48
+    num_query_groups: Optional[int] = 8
+    kv_channels: Optional[int] = None
+    init_method_std: float = 0.0134
+
+
+@dataclass
+class Nemotron4ModelProvider340B(NemotronModelProvider):
+    """
+    Configuration class for the Nemotron4 340B model, inheriting from NemotronModelProvider.
+    """
+
+    num_layers: int = 96
+    seq_length: int = 4096
+    hidden_size: int = 18432
+    ffn_hidden_size: int = 73728
+    num_attention_heads: int = 96
+    num_query_groups: Optional[int] = 8
+    kv_channels: Optional[int] = None
+    init_method_std: float = 0.0063
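
The base NemotronModelProvider defaults mirror the Nemotron3 4B shape, and each subclass overrides only the size-specific fields; the activation shared by every size is squared ReLU. A small illustrative check (assumes torch and the package from this commit are importable; plain dataclass defaults are read as class attributes, so nothing has to be instantiated):

import torch

from megatron.bridge.models.nemotron.nemotron_provider import (
    Nemotron3ModelProvider8B,
    NemotronModelProvider,
    squared_relu,
)

# squared_relu(x) == relu(x) ** 2, as defined above.
x = torch.tensor([-1.0, 0.5, 2.0])
print(squared_relu(x))  # tensor([0.0000, 0.2500, 4.0000])

# Defaults follow the Nemotron3 4B shape; subclasses override only what differs.
print(NemotronModelProvider.hidden_size, NemotronModelProvider.ffn_hidden_size)        # 3072 9216
print(Nemotron3ModelProvider8B.hidden_size, Nemotron3ModelProvider8B.ffn_hidden_size)  # 4096 16384
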
