32 changes: 32 additions & 0 deletions src/megatron/bridge/models/nemotron/__init__.py
@@ -0,0 +1,32 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from megatron.bridge.models.nemotron.nemotron_provider import (
    Nemotron3ModelProvider4B,
    Nemotron3ModelProvider8B,
    Nemotron3ModelProvider22B,
    Nemotron4ModelProvider15B,
    Nemotron4ModelProvider340B,
    NemotronModelProvider,
)


__all__ = [
"NemotronModelProvider",
"Nemotron3ModelProvider4B",
"Nemotron3ModelProvider8B",
"Nemotron3ModelProvider22B",
"Nemotron4ModelProvider15B",
"Nemotron4ModelProvider340B",
]
142 changes: 142 additions & 0 deletions src/megatron/bridge/models/nemotron/nemotron_provider.py
@@ -0,0 +1,142 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from dataclasses import dataclass, field
from typing import Callable, Optional

import torch

from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.utils import fusions


logger = logging.getLogger(__name__)


def squared_relu(x):
"""Squared ReLU activation function."""
return torch.pow(torch.nn.functional.relu(x), 2)


@dataclass
class NemotronModelProvider(GPTModelProvider):
"""Configuration class for Nemotron models."""

# configs that are common across model sizes
normalization: str = "LayerNorm"
activation_func: Callable = squared_relu
position_embedding_type: str = "rope"
share_embeddings_and_output_weights: bool = False
add_bias_linear: bool = False

hidden_dropout: float = 0.0
attention_dropout: float = 0.0
rotary_percent: float = 0.5
masked_softmax_fusion: bool = field(default_factory=fusions.can_enable_masked_softmax_fusion)
persist_layer_norm: bool = True
bias_dropout_add_fusion: bool = False
layernorm_zero_centered_gamma: bool = True
cross_entropy_loss_fusion: bool = True
apply_rope_fusion: bool = field(default_factory=fusions.can_enable_apply_rope_fusion)

# Nemotron3Config4B as default configs
num_layers: int = 32
seq_length: int = 4096
hidden_size: int = 3072
ffn_hidden_size: int = 9216
num_attention_heads: int = 24
num_query_groups: Optional[int] = 8
kv_channels: Optional[int] = 128
init_method_std: float = 0.0134


@dataclass
class Nemotron3ModelProvider4B(NemotronModelProvider):
"""
Configuration class for the Nemotron3 4B model, inheriting from NemotronModelProvider.
"""

num_layers: int = 32
seq_length: int = 4096
hidden_size: int = 3072
ffn_hidden_size: int = 9216
num_attention_heads: int = 24
num_query_groups: int = 8
kv_channels: Optional[int] = 128
init_method_std: float = 0.0134


@dataclass
class Nemotron3ModelProvider8B(NemotronModelProvider):
"""
Configuration class for the Nemotron3 8B model, inheriting from NemotronModelProvider.
"""

num_layers: int = 32
seq_length: int = 4096
hidden_size: int = 4096
ffn_hidden_size: int = 16384
num_attention_heads: int = 32
num_query_groups: Optional[int] = None
kv_channels: Optional[int] = None
init_method_std: float = 0.010


@dataclass
class Nemotron3ModelProvider22B(NemotronModelProvider):
"""
Configuration class for the Nemotron3 22B model, inheriting from NemotronModelProvider.
"""

num_layers: int = 40
seq_length: int = 4096
hidden_size: int = 6144
ffn_hidden_size: int = 24576
num_attention_heads: int = 48
num_query_groups: Optional[int] = None
kv_channels: Optional[int] = None
init_method_std: float = 0.008


@dataclass
class Nemotron4ModelProvider15B(NemotronModelProvider):
"""
Configuration class for the Nemotron4 15B model, inheriting from NemotronModelProvider.
"""

num_layers: int = 32
seq_length: int = 4096
hidden_size: int = 6144
ffn_hidden_size: int = 24576
num_attention_heads: int = 48
num_query_groups: Optional[int] = 8
kv_channels: Optional[int] = None
init_method_std: float = 0.0134


@dataclass
class Nemotron4ModelProvider340B(NemotronModelProvider):
"""
Configuration class for the Nemotron4 340B model, inheriting from NemotronModelProvider.
"""

num_layers: int = 96
seq_length: int = 4096
hidden_size: int = 18432
ffn_hidden_size: int = 73728
num_attention_heads: int = 96
num_query_groups: Optional[int] = 8
kv_channels: Optional[int] = None
init_method_std: float = 0.0063
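
A minimal usage sketch (not part of the diff), assuming the providers behave as plain dataclasses; the tensor-parallel override below is only an illustrative value:

import torch

from megatron.bridge.models.nemotron import Nemotron3ModelProvider8B
from megatron.bridge.models.nemotron.nemotron_provider import squared_relu

# Instantiate the 8B provider; unspecified fields keep the defaults defined above.
provider = Nemotron3ModelProvider8B(tensor_model_parallel_size=2)
assert provider.normalization == "LayerNorm"  # shared Nemotron default
assert provider.ffn_hidden_size == 16384  # 8B-specific default

# The shared activation is squared ReLU: relu(x) ** 2.
assert torch.equal(squared_relu(torch.tensor([-1.0, 2.0])), torch.tensor([0.0, 4.0]))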
15 changes: 15 additions & 0 deletions src/megatron/bridge/recipes/nemotron/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Nemotron recipes for Megatron-Bridge."""
207 changes: 207 additions & 0 deletions src/megatron/bridge/recipes/nemotron/nemotron3_22b.py
@@ -0,0 +1,207 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union

import torch

from megatron.bridge.models.nemotron import Nemotron3ModelProvider22B
from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing
from megatron.bridge.recipes.utils.pretrain_utils import (
    create_checkpoint_config,
    create_dataset_config,
    create_ddp_config,
    create_logger_config,
    create_rng_config,
    create_tokenizer_config,
    create_training_config,
    setup_output_dirs,
)
from megatron.bridge.training.comm_overlap import (
    CommOverlapConfig,
)
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.mixed_precision import MixedPrecisionConfig


def model_config(
    tensor_parallelism: int = 2,
    pipeline_parallelism: int = 4,
    pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16,
    virtual_pipeline_parallelism: Optional[int] = 10,
    context_parallelism: int = 1,
    sequence_parallelism: bool = False,
    sequence_length: int = 4096,
) -> Nemotron3ModelProvider22B:
    """
    Configure the Nemotron3 22B model.

    Args:
        tensor_parallelism (int): Degree of tensor model parallelism.
        pipeline_parallelism (int): Degree of pipeline model parallelism.
        pipeline_parallelism_dtype (Optional[torch.dtype]): Data type for pipeline parallelism.
        virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism.
        context_parallelism (int): Degree of context parallelism.
        sequence_parallelism (bool): Whether to use sequence parallelism.
        sequence_length (int): Sequence length for the model.

    Returns:
        Nemotron3ModelProvider22B: Configuration for the Nemotron3 22B model.
    """
    return Nemotron3ModelProvider22B(
        tensor_model_parallel_size=tensor_parallelism,
        pipeline_model_parallel_size=pipeline_parallelism,
        pipeline_dtype=pipeline_parallelism_dtype,
        virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism,
        context_parallel_size=context_parallelism,
        sequence_parallel=sequence_parallelism,
        seq_length=sequence_length,
    )
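
A short sketch (not part of the diff) of calling model_config directly with non-default parallelism; the override values are hypothetical:

# Hypothetical layout: TP=4, PP=2, no virtual pipelining.
cfg_22b = model_config(
    tensor_parallelism=4,
    pipeline_parallelism=2,
    virtual_pipeline_parallelism=None,
)
assert cfg_22b.tensor_model_parallel_size == 4
assert cfg_22b.hidden_size == 6144  # 22B architecture default from the provider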


def pretrain_config(
    dir: Optional[str] = None,
    name: str = "default",
    # Dataset configuration
    data_paths: Optional[List[str]] = None,
    data_args_path: Optional[str] = None,
    train_data_path: Optional[List[str]] = None,
    valid_data_path: Optional[List[str]] = None,
    test_data_path: Optional[List[str]] = None,
    per_split_data_args_path: Optional[str] = None,
    mock: bool = False,
    # Model configuration
    tensor_parallelism: int = 2,
    pipeline_parallelism: int = 4,
    pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16,
    virtual_pipeline_parallelism: Optional[int] = 10,
    context_parallelism: int = 1,
    sequence_parallelism: bool = False,
    sequence_length: int = 4096,
    # Training hyperparameters
    train_iters: int = 300000,
    global_batch_size: int = 32,
    micro_batch_size: int = 1,
    lr: float = 1e-4,
    min_lr: float = 1e-5,
    lr_warmup_iters: int = 500,
    # Precision recipe
    precision_config: Optional[Union[MixedPrecisionConfig, str]] = "bf16_mixed",
    comm_overlap_config: Optional[CommOverlapConfig] = None,
) -> ConfigContainer:
"""
Create a pre-training configuration for Nemotron3 22B model.

Args:
dir (Optional[str]): Base directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
data_paths (Optional[List[str]]): List of paths to dataset files. If None, mock data will be used.
data_args_path (Optional[str]): Path to file containing data arguments.
train_data_path (Optional[List[str]]): List of training data paths.
valid_data_path (Optional[List[str]]): List of validation data paths.
test_data_path (Optional[List[str]]): List of test data paths.
per_split_data_args_path (Optional[str]): Path to JSON file with per-split data configuration.
mock (bool): Whether to use mock data. If True, ignores data_paths.
tensor_parallelism (int): Degree of tensor model parallelism.
pipeline_parallelism (int): Degree of pipeline model parallelism.
pipeline_parallelism_dtype (Optional[torch.dtype]): Data type for pipeline parallelism.
virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism.
context_parallelism (int): Degree of context parallelism to be passed to model_config.
sequence_parallelism (bool): Whether to use sequence parallelism.
sequence_length (int): Sequence length for the model.
train_iters (int): Total number of training iterations.
global_batch_size (int): Global batch size for training.
micro_batch_size (int): Micro batch size for training.
lr (float): Learning rate.
min_lr (float): Minimum learning rate for cosine decay.
lr_warmup_iters (int): Number of warmup iterations for the learning rate.
precision_config (Optional[Union[MixedPrecisionConfig, str]]): Precision configuration for the model.
comm_overlap_config (Optional[CommOverlapConfig]): Communication overlap configuration for the model.

Returns:
ConfigContainer: Configuration for pre-training.
"""
    # Set up output directories
    run_output_dir, checkpoint_dir, tensorboard_dir = setup_output_dirs(dir, name)

    # Create model configuration
    model_cfg = model_config(
        tensor_parallelism=tensor_parallelism,
        pipeline_parallelism=pipeline_parallelism,
        pipeline_parallelism_dtype=pipeline_parallelism_dtype,
        virtual_pipeline_parallelism=virtual_pipeline_parallelism,
        context_parallelism=context_parallelism,
        sequence_parallelism=sequence_parallelism,
        sequence_length=sequence_length,
    )

    # Create optimizer and scheduler configurations
    opt_config, scheduler = distributed_fused_adam_with_cosine_annealing(
        max_lr=lr,
        min_lr=min_lr,
        lr_warmup_iters=lr_warmup_iters,
        lr_decay_iters=train_iters,
    )

    # Create dataset configuration
    dataset_cfg = create_dataset_config(
        sequence_length=sequence_length,
        data_paths=data_paths,
        data_args_path=data_args_path,
        train_data_path=train_data_path,
        valid_data_path=valid_data_path,
        test_data_path=test_data_path,
        per_split_data_args_path=per_split_data_args_path,
        mock=mock,
    )

    # Create communication overlap configuration if not provided
    if comm_overlap_config is None:
        comm_overlap_config = CommOverlapConfig(
            tp_comm_overlap=True,
            defer_embedding_wgrad_compute=True,
            wgrad_deferral_limit=22,
            overlap_param_gather_with_optimizer_step=False,  # Currently disabled due to an issue with checkpointing
        )

    final_precision_config = precision_config
    if isinstance(precision_config, str):
        # Mixed precision configuration
        from megatron.bridge.training.mixed_precision import get_mixed_precision_config

        final_precision_config = get_mixed_precision_config(precision_config)

    final_precision_config.grad_reduce_in_fp32 = False

    # Config Container
    cfg = ConfigContainer(
        model=model_cfg,
        train=create_training_config(
            train_iters=train_iters,
            global_batch_size=global_batch_size,
            micro_batch_size=micro_batch_size,
        ),
        optimizer=opt_config,
        scheduler=scheduler,
        ddp=create_ddp_config(),
        dataset=dataset_cfg,
        logger=create_logger_config(tensorboard_dir=tensorboard_dir),
        tokenizer=create_tokenizer_config(),
        checkpoint=create_checkpoint_config(checkpoint_dir=checkpoint_dir),
        rng=create_rng_config(),
        comm_overlap=comm_overlap_config,
        mixed_precision=final_precision_config,
    )

    return cfg
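
A minimal usage sketch (not part of the diff) of the new recipe, with mock data and a shortened run; the values shown are illustrative only:

from megatron.bridge.recipes.nemotron.nemotron3_22b import pretrain_config

# Hypothetical smoke-test settings: mock data, 100 iterations.
cfg = pretrain_config(name="nemotron3_22b_smoke", mock=True, train_iters=100)

# The container bundles model, optimizer, dataset, and parallelism settings.
print(cfg.model.tensor_model_parallel_size)  # 2 (recipe default)
print(cfg.model.pipeline_model_parallel_size)  # 4 (recipe default)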