
Commit 4e080b4

nemotron bridge
Signed-off-by: Ananth Subramaniam <[email protected]>
1 parent af5bc9a commit 4e080b4

7 files changed: +525 −281 lines changed

src/megatron/bridge/models/__init__.py

Lines changed: 5 additions & 8 deletions
@@ -54,7 +54,6 @@
     Llama32ModelProvider3B,
     LlamaModelProvider,
 )
-<<<<<<< HEAD
 from megatron.bridge.models.mamba.mamba_provider import (
     MambaProvider,
     MambaProvider1_3B,
@@ -73,15 +72,10 @@
     NemotronHModelProvider,
     NemotronNano9Bv2Provider,
     NemotronNano12Bv2Provider,
-=======
+)
 from megatron.bridge.models.nemotron import (
-    Nemotron3ModelProvider4B,
-    Nemotron3ModelProvider8B,
-    Nemotron3ModelProvider22B,
-    Nemotron4ModelProvider15B,
-    Nemotron4ModelProvider340B,
+    NemotronBridge,
     NemotronModelProvider,
->>>>>>> c3d509cf (nemotron model provider)
 )
 from megatron.bridge.models.qwen import (
     Qwen2ModelProvider,
@@ -189,6 +183,9 @@
     "MambaProvider780M",
     "NVIDIAMambaHybridProvider8B",
     "NVIDIAMambaProvider8B",
+    # Nemotron Models
+    "NemotronBridge",
+    "NemotronModelProvider",
     # VL Models
     "Qwen25VLModel",
     "Qwen25VLBridge",

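After this change, both new Nemotron symbols are re-exported from the top of the models package. A minimal import sketch, assuming the package layout shown in this diff:

# Illustrative only: the new top-level exports added to __all__ above.
from megatron.bridge.models import NemotronBridge, NemotronModelProvider
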
src/megatron/bridge/models/nemotron/__init__.py

Lines changed: 3 additions & 13 deletions
@@ -12,21 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from megatron.bridge.models.nemotron.nemotron_provider import (
-    Nemotron3ModelProvider4B,
-    Nemotron3ModelProvider8B,
-    Nemotron3ModelProvider22B,
-    Nemotron4ModelProvider15B,
-    Nemotron4ModelProvider340B,
-    NemotronModelProvider,
-)
+from megatron.bridge.models.nemotron.nemotron_bridge import NemotronBridge
+from megatron.bridge.models.nemotron.nemotron_provider import NemotronModelProvider
 
 
 __all__ = [
+    "NemotronBridge",
     "NemotronModelProvider",
-    "Nemotron3ModelProvider4B",
-    "Nemotron3ModelProvider8B",
-    "Nemotron3ModelProvider22B",
-    "Nemotron4ModelProvider15B",
-    "Nemotron4ModelProvider340B",
 ]
src/megatron/bridge/models/nemotron/nemotron_bridge.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from megatron.core.models.gpt.gpt_model import GPTModel
+from transformers import NemotronForCausalLM
+
+from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
+from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge
+from megatron.bridge.models.conversion.param_mapping import (
+    AutoMapping,
+    QKVMapping,
+)
+from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM
+from megatron.bridge.models.nemotron.nemotron_provider import NemotronModelProvider
+
+
+@MegatronModelBridge.register_bridge(source=NemotronForCausalLM, target=GPTModel)
+class NemotronBridge(MegatronModelBridge):
+    """
+    Megatron Bridge for Nemotron Causal LM.
+
+    As a user you would not use this bridge directly, but through `AutoBridge`.
+
+    Example:
+        >>> from megatron.bridge import AutoBridge
+        >>> bridge = AutoBridge.from_hf_pretrained("nvidia/Nemotron-4-340B-Instruct")
+        >>> provider = bridge.to_megatron_provider()
+    """
+
+    def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> NemotronModelProvider:
+        hf_config = hf_pretrained.config
+
+        provider = NemotronModelProvider(
+            num_layers=hf_config.num_hidden_layers,
+            hidden_size=hf_config.hidden_size,
+            ffn_hidden_size=hf_config.intermediate_size,
+            num_attention_heads=hf_config.num_attention_heads,
+            init_method_std=hf_config.initializer_range,
+            layernorm_epsilon=hf_config.norm_eps,
+            num_query_groups=hf_config.num_key_value_heads,
+            seq_length=hf_config.max_position_embeddings,
+            rotary_base=hf_config.rope_theta,
+            rotary_percent=hf_config.partial_rotary_factor,
+            kv_channels=getattr(hf_config, "head_dim", None),
+            make_vocab_size_divisible_by=self.make_vocab_size_divisible_by(hf_config.vocab_size),
+            share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False),
+            fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16),
+            bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16),
+            params_dtype=self.dtype_from_hf(hf_config, default=torch.float32),
+            generation_config=hf_pretrained.generation_config,
+            vocab_size=hf_config.vocab_size,
+        )
+
+        return provider
+
+    def mapping_registry(self) -> MegatronMappingRegistry:
+        # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format
+        # First create simple 1:1 parameter mappings using a dictionary for readability
+
+        # Dictionary maps Megatron parameter names -> HF parameter names
+        # Supports wildcard (*) patterns for layer-specific parameters
+        param_mappings = {
+            "embedding.word_embeddings.weight": "model.embed_tokens.weight",
+            "output_layer.weight": "lm_head.weight",
+            "decoder.final_layernorm.weight": "model.norm.weight",
+            "decoder.final_layernorm.bias": "model.norm.bias",
+            "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight",
+            "decoder.layers.*.self_attention.linear_qkv.layer_norm_bias": "model.layers.*.input_layernorm.bias",
+            "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight",
+            "decoder.layers.*.mlp.linear_fc1.layer_norm_bias": "model.layers.*.post_attention_layernorm.bias",
+            "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight",
+            "decoder.layers.*.mlp.linear_fc1.weight": "model.layers.*.mlp.up_proj.weight",
+            "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight",
+        }
+
+        mapping_list = []
+        # Convert each dictionary entry to AutoMapping(megatron_param, hf_param)
+        for megatron_param, hf_param in param_mappings.items():
+            mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param))
+
+        # Add special mappings that require parameter concatenation/transformation
+        mapping_list.extend(
+            [
+                # QKV: Combine separate Q, K, V matrices into single QKV matrix
+                QKVMapping(
+                    megatron_param="decoder.layers.*.self_attention.linear_qkv.weight",
+                    q="model.layers.*.self_attn.q_proj.weight",
+                    k="model.layers.*.self_attn.k_proj.weight",
+                    v="model.layers.*.self_attn.v_proj.weight",
+                ),
+            ]
+        )
+
+        return MegatronMappingRegistry(*mapping_list)

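Taken together: register_bridge makes the class discoverable by AutoBridge, provider_bridge translates the Hugging Face config into a NemotronModelProvider, and mapping_registry declares how Megatron parameter names map to Hugging Face ones. A minimal end-to-end sketch mirroring the class docstring (the checkpoint name and the attribute reads below are illustrative, not part of this commit):

# Sketch only: conversion path through AutoBridge, as shown in the docstring.
from megatron.bridge import AutoBridge

bridge = AutoBridge.from_hf_pretrained("nvidia/Nemotron-4-340B-Instruct")
provider = bridge.to_megatron_provider()

# provider_bridge copies these from the HF config, so they should now be populated.
print(provider.num_layers, provider.hidden_size, provider.num_query_groups)
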
src/megatron/bridge/models/nemotron/nemotron_provider.py

Lines changed: 1 addition & 91 deletions
@@ -14,7 +14,7 @@
 
 import logging
 from dataclasses import dataclass, field
-from typing import Callable, Optional
+from typing import Callable
 
 import torch
 
@@ -50,93 +50,3 @@ class NemotronModelProvider(GPTModelProvider):
     layernorm_zero_centered_gamma: bool = True
     cross_entropy_loss_fusion: bool = True
     apply_rope_fusion: bool = field(default_factory=fusions.can_enable_apply_rope_fusion)
-
-    # Nemotron3Config4B as default configs
-    num_layers: int = 32
-    seq_length: int = 4096
-    hidden_size: int = 3072
-    ffn_hidden_size: int = 9216
-    num_attention_heads: int = 24
-    num_query_groups: Optional[int] = 8
-    kv_channels: Optional[int] = 128
-    init_method_std: float = 0.0134
-
-
-@dataclass
-class Nemotron3ModelProvider4B(NemotronModelProvider):
-    """
-    Configuration class for the Nemotron3 4B model, inheriting from NemotronModelProvider.
-    """
-
-    num_layers: int = 32
-    seq_length: int = 4096
-    hidden_size: int = 3072
-    ffn_hidden_size: int = 9216
-    num_attention_heads: int = 24
-    num_query_groups: int = 8
-    kv_channels: Optional[int] = 128
-    init_method_std: float = 0.0134
-
-
-@dataclass
-class Nemotron3ModelProvider8B(NemotronModelProvider):
-    """
-    Configuration class for the Nemotron3 8B model, inheriting from NemotronModelProvider.
-    """
-
-    num_layers: int = 32
-    seq_length: int = 4096
-    hidden_size: int = 4096
-    ffn_hidden_size: int = 16384
-    num_attention_heads: int = 32
-    num_query_groups: Optional[int] = None
-    kv_channels: Optional[int] = None
-    init_method_std: float = 0.010
-
-
-@dataclass
-class Nemotron3ModelProvider22B(NemotronModelProvider):
-    """
-    Configuration class for the Nemotron3 22B model, inheriting from NemotronModelProvider.
-    """
-
-    num_layers: int = 40
-    seq_length: int = 4096
-    hidden_size: int = 6144
-    ffn_hidden_size: int = 24576
-    num_attention_heads: int = 48
-    num_query_groups: Optional[int] = None
-    kv_channels: Optional[int] = None
-    init_method_std: float = 0.008
-
-
-@dataclass
-class Nemotron4ModelProvider15B(NemotronModelProvider):
-    """
-    Configuration class for the Nemotron4 15B model, inheriting from NemotronModelProvider.
-    """
-
-    num_layers: int = 32
-    seq_length: int = 4096
-    hidden_size: int = 6144
-    ffn_hidden_size: int = 24576
-    num_attention_heads: int = 48
-    num_query_groups: Optional[int] = 8
-    kv_channels: Optional[int] = None
-    init_method_std: float = 0.0134
-
-
-@dataclass
-class Nemotron4ModelProvider340B(NemotronModelProvider):
-    """
-    Configuration class for the Nemotron4 340B model, inheriting from NemotronModelProvider.
-    """
-
-    num_layers: int = 96
-    seq_length: int = 4096
-    hidden_size: int = 18432
-    ffn_hidden_size: int = 73728
-    num_attention_heads: int = 96
-    num_query_groups: Optional[int] = 8
-    kv_channels: Optional[int] = None
-    init_method_std: float = 0.0063

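Note that the per-size providers (Nemotron3 4B/8B/22B and Nemotron4 15B/340B) are deleted here along with the hard-coded 4B defaults, so a specific size is now expressed by instantiating NemotronModelProvider with explicit hyperparameters. A minimal sketch reusing the Nemotron3 4B values removed above (illustrative only; these classes no longer ship with the package):

from megatron.bridge.models.nemotron import NemotronModelProvider

# Values copied from the removed Nemotron3ModelProvider4B defaults.
nemotron3_4b = NemotronModelProvider(
    num_layers=32,
    seq_length=4096,
    hidden_size=3072,
    ffn_hidden_size=9216,
    num_attention_heads=24,
    num_query_groups=8,
    kv_channels=128,
    init_method_std=0.0134,
)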