2 changes: 1 addition & 1 deletion examples/conversion/compare_hf_and_megatron/compare.py
@@ -729,7 +729,7 @@ def compare_models_one_step(args) -> None:
vocab_size = getattr(
tokenizer, "vocab_size", len(tokenizer.vocab) if hasattr(tokenizer, "vocab") else 32000
)
hf_logits = torch.zeros(vocab_size, device=input_ids.device, dtype=torch.bfloat16)
hf_logits = torch.zeros(vocab_size, device=input_ids.device, dtype=torch.float32)

# Broadcast from rank 0 to all ranks
torch.distributed.broadcast(hf_next_token, 0)
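The dtype change above matters because of how `torch.distributed.broadcast` works: every rank must preallocate a buffer whose shape and dtype match the tensor sent from rank 0. Below is a minimal sketch of that pattern, not taken from compare.py, using only standard `torch.distributed` calls:

```python
import torch
import torch.distributed as dist


def broadcast_logits_example(vocab_size: int = 32000) -> torch.Tensor:
    """Assumes a process group is already initialized (e.g. via torchrun)."""
    if dist.get_rank() == 0:
        logits = torch.randn(vocab_size, dtype=torch.float32)  # produced only on rank 0
    else:
        logits = torch.zeros(vocab_size, dtype=torch.float32)  # placeholder; dtype/shape must match rank 0
    dist.broadcast(logits, src=0)
    return logits
```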
106 changes: 106 additions & 0 deletions examples/recipes/nemotron_3/finetune_nemotron_3_nano.py
@@ -0,0 +1,106 @@
#!/usr/bin/env python3
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import sys
from typing import Tuple

import torch
from omegaconf import OmegaConf

from megatron.bridge.recipes.nemotronh.nemotron_3_nano import (
nemotron_3_nano_finetune_config as finetune_config,
)
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.finetune import finetune
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.utils.omegaconf_utils import (
apply_overrides,
create_omegaconf_dict_config,
parse_hydra_overrides,
)


logger: logging.Logger = logging.getLogger(__name__)


def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]:
"""Parse command line arguments, separating known script args from OmegaConf overrides."""
parser = argparse.ArgumentParser(
description="Finetune Nemotron 3 Nano model using Megatron-Bridge with YAML and CLI overrides",
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--config-file",
type=str,
help="Path to the YAML OmegaConf override file.",
)
parser.add_argument("--peft", type=str, help="Type of PEFT to use")
parser.add_argument("--packed-sequence", action="store_true", help="Whether to use sequence packing")
parser.add_argument("--seq-length", type=int, default=2048, help="Sequence length")

# Parse known args for the script, remaining will be treated as overrides
args, cli_dotlist_overrides = parser.parse_known_args()
return args, cli_dotlist_overrides


def main() -> None:
"""
Entry point for the Nemotron 3 Nano finetuning script.
"""
args, cli_overrides = parse_cli_args()

cfg: ConfigContainer = finetune_config(
seq_length=args.seq_length, peft=args.peft, packed_sequence=args.packed_sequence
)
cfg.model.seq_length = args.seq_length

# Convert the initial Python dataclass to an OmegaConf DictConfig for merging
merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg)

# Load and merge YAML overrides if a config file is provided
if args.config_file:
logger.debug(f"Loading YAML overrides from: {args.config_file}")
if not os.path.exists(args.config_file):
logger.error(f"Override YAML file not found: {args.config_file}")
sys.exit(1)
yaml_overrides_omega = OmegaConf.load(args.config_file)
merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega)
logger.debug("YAML overrides merged successfully.")

# Apply command-line overrides using Hydra-style parsing
if cli_overrides:
logger.debug(f"Applying Hydra-style command-line overrides: {cli_overrides}")
merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides)
logger.debug("Hydra-style command-line overrides applied successfully.")

# Apply the final merged OmegaConf configuration back to the original ConfigContainer
logger.debug("Applying final merged configuration back to Python ConfigContainer...")
final_overrides_as_dict = OmegaConf.to_container(merged_omega_conf, resolve=True)
# Apply overrides while preserving excluded fields
apply_overrides(cfg, final_overrides_as_dict, excluded_fields)

# Start training
logger.debug("Starting finetuning...")
finetune(config=cfg, forward_step_func=forward_step)

if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()


if __name__ == "__main__":
main()
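The configuration flow in this script is: build the recipe dataclass, convert it to an OmegaConf `DictConfig`, layer YAML and Hydra-style CLI overrides on top, then write the merged values back onto the dataclass. A rough sketch of that layering using only stock OmegaConf calls (it does not use the repository's `create_omegaconf_dict_config`/`parse_hydra_overrides` helpers, and the config fields are illustrative):

```python
# Illustrative only: the same override layering the script performs, with plain OmegaConf.
from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class TrainCfg:
    seq_length: int = 2048
    lr: float = 3e-4


base = OmegaConf.structured(TrainCfg())                        # dataclass defaults
yaml_overrides = OmegaConf.create({"lr": 1e-4})                # stands in for the --config-file contents
cli_overrides = OmegaConf.from_dotlist(["seq_length=4096"])    # stands in for leftover dot-list CLI args

merged = OmegaConf.merge(base, yaml_overrides, cli_overrides)  # later sources win
print(OmegaConf.to_container(merged, resolve=True))            # {'seq_length': 4096, 'lr': 0.0001}
```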
105 changes: 105 additions & 0 deletions examples/recipes/nemotron_3/pretrain_nemotron_3_nano.py
@@ -0,0 +1,105 @@
#!/usr/bin/env python3
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import sys
from typing import Tuple

import torch
from omegaconf import OmegaConf

from megatron.bridge.recipes.nemotronh.nemotron_3_nano import (
nemotron_3_nano_pretrain_config as pretrain_config,
)
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain
from megatron.bridge.training.utils.omegaconf_utils import (
apply_overrides,
create_omegaconf_dict_config,
parse_hydra_overrides,
)


logger: logging.Logger = logging.getLogger(__name__)


def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]:
"""Parse command line arguments, separating known script args from OmegaConf overrides."""
parser = argparse.ArgumentParser(
description="Pretrain Nemotron 3 Nano model using Megatron-Bridge with YAML and CLI overrides",
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--config-file",
type=str,
help="Path to the YAML OmegaConf override file.",
)
parser.add_argument("--per-split-data-args-path", type=str, help="Path to the per split data args file.")
parser.add_argument("--tokenizer-model", type=str, help="Path to the tokenizer model file.")

# Parse known args for the script, remaining will be treated as overrides
args, cli_dotlist_overrides = parser.parse_known_args()
return args, cli_dotlist_overrides


def main() -> None:
"""
Entry point for the Nemotron 3 Nano pretraining script.
"""
args, cli_overrides = parse_cli_args()

cfg: ConfigContainer = pretrain_config(
per_split_data_args_path=args.per_split_data_args_path,
tokenizer_model=args.tokenizer_model,
)

# Convert the initial Python dataclass to an OmegaConf DictConfig for merging
merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg)

# Load and merge YAML overrides if a config file is provided
if args.config_file:
logger.debug(f"Loading YAML overrides from: {args.config_file}")
if not os.path.exists(args.config_file):
logger.error(f"Override YAML file not found: {args.config_file}")
sys.exit(1)
yaml_overrides_omega = OmegaConf.load(args.config_file)
merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega)
logger.debug("YAML overrides merged successfully.")

# Apply command-line overrides using Hydra-style parsing
if cli_overrides:
logger.debug(f"Applying Hydra-style command-line overrides: {cli_overrides}")
merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides)
logger.debug("Hydra-style command-line overrides applied successfully.")

# Apply the final merged OmegaConf configuration back to the original ConfigContainer
logger.debug("Applying final merged configuration back to Python ConfigContainer...")
final_overrides_as_dict = OmegaConf.to_container(merged_omega_conf, resolve=True)
# Apply overrides while preserving excluded fields
apply_overrides(cfg, final_overrides_as_dict, excluded_fields)

# Start training
logger.debug("Starting pretraining...")
pretrain(config=cfg, forward_step_func=forward_step)

if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions src/megatron/bridge/models/__init__.py
@@ -151,6 +151,7 @@
NemotronVLModel,
)
from megatron.bridge.models.nemotronh.nemotron_h_provider import (
Nemotron3NanoProvider,
NemotronHModel4BProvider,
NemotronHModel8BProvider,
NemotronHModel47BProvider,
@@ -326,6 +327,7 @@
"NemotronHModel56BProvider",
"NemotronNano9Bv2Provider",
"NemotronNano12Bv2Provider",
"Nemotron3NanoProvider",
"MambaModelProvider",
"MambaModelProvider1P3B",
"MambaModelProvider2P7B",
6 changes: 6 additions & 0 deletions src/megatron/bridge/models/model_provider.py
@@ -578,6 +578,12 @@ def get_model(
if (model_config.fp16 or model_config.bf16) and mixed_precision_wrapper is not None:
model = [mixed_precision_wrapper(model_config, model_module) for model_module in model]

# Keep the expert bias in float32 even after the model is wrapped in Float16Module
for model_module in model:
for submodule in model_module.modules():
if hasattr(submodule, "_maintain_float32_expert_bias"):
submodule._maintain_float32_expert_bias()

if correct_amax_history_if_needed is not None:
correct_amax_history_if_needed(model)

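The loop added in get_model calls an opt-in hook by name rather than by type: any submodule that defines `_maintain_float32_expert_bias` gets a chance to re-cast its bias after the half-precision wrapping. A hypothetical sketch of a module participating in that hook; the method body is an assumption for illustration, not the actual Megatron implementation:

```python
import torch
import torch.nn as nn


class ToyRouter(nn.Module):
    """Hypothetical router-like module that opts into the float32 expert-bias hook."""

    def __init__(self, num_experts: int) -> None:
        super().__init__()
        # Registered as a buffer, so a half-precision cast would normally sweep it along.
        self.register_buffer("expert_bias", torch.zeros(num_experts))

    def _maintain_float32_expert_bias(self) -> None:
        # Assumed behavior: force the bias back to float32 regardless of the wrapper's dtype.
        self.expert_bias = self.expert_bias.float()


model = nn.Sequential(ToyRouter(num_experts=8)).bfloat16()
for submodule in model.modules():
    if hasattr(submodule, "_maintain_float32_expert_bias"):
        submodule._maintain_float32_expert_bias()
assert model[0].expert_bias.dtype == torch.float32
```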
2 changes: 2 additions & 0 deletions src/megatron/bridge/models/nemotronh/__init__.py
@@ -14,6 +14,7 @@

from megatron.bridge.models.nemotronh.nemotron_h_bridge import NemotronHBridge
from megatron.bridge.models.nemotronh.nemotron_h_provider import (
Nemotron3NanoProvider,
NemotronHModel4BProvider,
NemotronHModel8BProvider,
NemotronHModel47BProvider,
@@ -44,4 +45,5 @@
"NemotronHModel56BProvider",
"NemotronNano9Bv2Provider",
"NemotronNano12Bv2Provider",
"Nemotron3NanoProvider",
]
23 changes: 23 additions & 0 deletions src/megatron/bridge/models/nemotronh/nemotron_h_bridge.py
@@ -52,6 +52,21 @@ class NemotronHBridge(MegatronModelBridge):
def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> NemotronHModelProvider:
hf_config = hf_pretrained.config

configs = {}
# MoE configurations
if hasattr(hf_config, "n_routed_experts") and hf_config.n_routed_experts > 0:
configs.update(
{
"num_moe_experts": hf_config.n_routed_experts,
"moe_ffn_hidden_size": hf_config.moe_intermediate_size,
"moe_shared_expert_intermediate_size": hf_config.moe_shared_expert_intermediate_size,
"moe_router_topk": hf_config.num_experts_per_tok,
"moe_router_num_groups": hf_config.n_group,
"moe_router_group_topk": hf_config.topk_group,
"moe_router_topk_scaling_factor": hf_config.routed_scaling_factor,
}
)

return NemotronHModelProvider(
num_layers=hf_config.num_hidden_layers,
hidden_size=hf_config.hidden_size,
@@ -78,6 +93,7 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> NemotronHModelPr
mamba_num_groups=hf_config.n_groups,
mamba_state_dim=hf_config.ssm_state_size,
add_qkv_bias=hf_config.attention_bias,
**configs,
)

def mapping_registry(self) -> MegatronMappingRegistry:
@@ -105,6 +121,13 @@ def mapping_registry(self) -> MegatronMappingRegistry:
# TODO (@maanug): need to find a way to prune the vocab padding from the vocab dimension for these params
"embedding.word_embeddings.weight": "backbone.embeddings.weight",
"output_layer.weight": "lm_head.weight",
# MoE layers
"decoder.layers.*.mlp.router.weight": "backbone.layers.*.mixer.gate.weight",
"decoder.layers.*.mlp.router.expert_bias": "backbone.layers.*.mixer.gate.e_score_correction_bias",
"decoder.layers.*.mlp.experts.linear_fc1.weight*": "backbone.layers.*.mixer.experts.*.up_proj.weight",
"decoder.layers.*.mlp.experts.linear_fc2.weight*": "backbone.layers.*.mixer.experts.*.down_proj.weight",
"decoder.layers.*.mlp.shared_experts.linear_fc1.weight": "backbone.layers.*.mixer.shared_experts.up_proj.weight",
"decoder.layers.*.mlp.shared_experts.linear_fc2.weight": "backbone.layers.*.mixer.shared_experts.down_proj.weight",
}

mapping_list = []
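The new MoE entries follow the same wildcard convention as the existing mappings: each `*` in the Megatron name lines up positionally with a `*` in the Hugging Face name (the layer index, plus an expert index for the per-expert weights). A toy sketch of how one pattern pair could expand into concrete name pairs; this illustrates the convention only, not how MegatronMappingRegistry actually resolves it:

```python
# Hypothetical expansion of one wildcard mapping for a 2-layer, 2-expert toy model.
from itertools import product

megatron_pattern = "decoder.layers.*.mlp.experts.linear_fc1.weight*"
hf_pattern = "backbone.layers.*.mixer.experts.*.up_proj.weight"


def expand(pattern: str, values: tuple) -> str:
    """Substitute wildcard occurrences left to right with concrete indices."""
    out = pattern
    for value in values:
        out = out.replace("*", str(value), 1)
    return out


for layer, expert in product(range(2), range(2)):
    print(expand(megatron_pattern, (layer, expert)), "<-", expand(hf_pattern, (layer, expert)))
# decoder.layers.0.mlp.experts.linear_fc1.weight0 <- backbone.layers.0.mixer.experts.0.up_proj.weight
# ...
```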
35 changes: 35 additions & 0 deletions src/megatron/bridge/models/nemotronh/nemotron_h_provider.py
@@ -44,6 +44,17 @@ class NemotronHModelProvider(MambaModelProvider):
first_last_layers_bf16: bool = True
is_hybrid_model: bool = True

# MoE
moe_aux_loss_coeff: float = 0.0001
moe_router_score_function: str = "sigmoid"
moe_router_enable_expert_bias: bool = True
moe_router_load_balancing_type: str = "seq_aux_loss"
moe_router_dtype: str = "fp32"
moe_grouped_gemm: bool = True
moe_token_dispatcher_type: str = "alltoall"
moe_permute_fusion: bool = True
moe_shared_expert_overlap: bool = True


@dataclass
class NemotronHModelProvider4B(NemotronHModelProvider):
@@ -138,6 +149,30 @@ class NemotronNanoModelProvider12Bv2(NemotronHModelProvider):
seq_length: int = 131072


@dataclass
class Nemotron3NanoProvider(NemotronHModelProvider):
"""Configuration for a 3B parameter Nemotron 3 Nano model."""

seq_length: int = 262144
num_query_groups: int = 2
hybrid_override_pattern: str = "MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME"
num_layers: int = 52
hidden_size: int = 2688
mamba_num_heads: int = 64
kv_channels: int = 128
mamba_state_dim: int = 128
ffn_hidden_size: int = 1856
num_attention_heads: int = 32
mamba_head_dim: int = 64
num_moe_experts: int = 128
moe_ffn_hidden_size: int = 1856
moe_shared_expert_intermediate_size: int = 3712 # 1856 * 2 shared expert
moe_router_topk: int = 6
moe_router_topk_scaling_factor: float = 2.5
moe_router_num_groups: int = 1
moe_router_group_topk: int = 1


# -----------------------------------------------------------------------------
# Deprecated aliases (to be removed in a future release)
# -----------------------------------------------------------------------------
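In the NemotronH providers, `hybrid_override_pattern` encodes one character per decoder layer, so its length should equal `num_layers`. A small check of the Nemotron3NanoProvider values above, under the assumption (not stated in this file) that `M` marks Mamba layers, `*` attention layers, and `E` MoE layers:

```python
from collections import Counter

pattern = "MEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEM*EMEMEMEM*EMEMEMEME"
num_layers = 52

assert len(pattern) == num_layers        # one character per layer
print(dict(Counter(pattern)))            # {'M': 23, 'E': 23, '*': 6}
```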
8 changes: 8 additions & 0 deletions src/megatron/bridge/recipes/nemotronh/__init__.py
@@ -13,6 +13,11 @@
# limitations under the License.

# Nemotron 3 Nano models
from megatron.bridge.recipes.nemotronh.nemotron_3_nano import (
nemotron_3_nano_finetune_config,
nemotron_3_nano_pretrain_config,
)
# Nemotron Nano v2 models
from megatron.bridge.recipes.nemotronh.nemotron_nano_v2 import (
nemotron_nano_9b_v2_finetune_config,
nemotron_nano_9b_v2_pretrain_config,
@@ -48,4 +53,7 @@
"nemotron_nano_12b_v2_pretrain_config",
"nemotron_nano_9b_v2_finetune_config",
"nemotron_nano_12b_v2_finetune_config",
# Nemotron 3 Nano models
"nemotron_3_nano_pretrain_config",
"nemotron_3_nano_finetune_config",
]