Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions nemo/tron/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ class GPTDatasetConfig(MCoreGPTDatasetConfig):
data_sharding: bool = True
"""Disable data sharding."""

create_attention_mask_in_dataloader: bool = True
"""If set, do not create attention_masks in dataloader."""

def __post_init__(self) -> None:
super(MCoreGPTDatasetConfig, self).__post_init__()

Expand Down
86 changes: 15 additions & 71 deletions nemo/tron/examples/lingua-1b_dclm.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from functools import partial

import torch
import torch.distributed
from megatron.core import mpu
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig

from nemo.collections import llm
from nemo.collections.llm.gpt.model.base import gpt_data_step
from nemo.tron.api import megatron_pretrain
from nemo.tron.config import (
CheckpointConfig,
Expand All @@ -21,74 +32,7 @@
TrainingConfig,
)
from nemo.tron.data.dataset import get_blend_and_blend_per_split
from nemo.tron.state import GlobalState

# define spiky loss as a variation of 20% or more
SPIKY_LOSS_PERC = 0.2


def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
"""Loss function.

Args:
loss_mask (torch.Tensor): Used to mask out some portions of the loss
output_tensor (torch.Tensor): The tensor with the losses

Returns:
the loss scalar for this micro-batch
the number of non-padded tokens in this microbatch
a dict containing reporting metrics on the loss and number of tokens across
the data parallel ranks
"""
state = GlobalState()
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
total_tokens = loss_mask.sum()
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])

if state.cfg.model_config.context_parallel_size > 1:
torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())

# Reduce loss for logging.
reporting_loss = loss.clone().detach()
torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())

local_num_tokens = loss[1].clone().detach().to(torch.int)
return (
loss[0] * state.cfg.model_config.context_parallel_size,
local_num_tokens,
{"lm loss": (reporting_loss[0], reporting_loss[1])},
)


def forward_step(data_iterator, model):
"""Forward training step.

Args:
data_iterator : Input data iterator
model (GPTModel): The GPT Model
"""
timers = GlobalState().timers

# Get the batch.
timers("batch-generator", log_level=2).start()
batch = gpt_data_step(data_iterator)
if "attention_mask" not in batch:
batch["attention_mask"] = None

tokens, labels, loss_mask, attention_mask, position_ids = (
batch["tokens"],
batch["labels"],
batch["loss_mask"],
batch["attention_mask"],
batch["position_ids"],
)
timers("batch-generator").stop()

output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

return output_tensor, partial(loss_func, loss_mask)

from nemo.tron.llm.gpt import forward_step

if __name__ == "__main__":
global_batch_size = 256
Expand Down
13 changes: 13 additions & 0 deletions nemo/tron/llm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
62 changes: 62 additions & 0 deletions nemo/tron/llm/gpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Iterable

from megatron.core import parallel_state
from megatron.core.models.gpt import GPTModel
from megatron.core.utils import get_batch_on_this_cp_rank

from nemo.tron.config import ConfigContainer
from nemo.tron.llm.utils import get_batch_on_this_tp_rank
from nemo.tron.losses import masked_next_token_loss
from nemo.tron.state import GlobalState


def get_batch(data_iterator, cfg: ConfigContainer):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be good to add typehint + docs for the return value

"""Generate a batch."""

if (not parallel_state.is_pipeline_first_stage()) and (not parallel_state.is_pipeline_last_stage()):
return None, None, None, None, None

# get batches based on the TP rank you are on
batch = get_batch_on_this_tp_rank(data_iterator, cfg)

# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)

return batch.values()


def forward_step(state: GlobalState, data_iterator: Iterable, model: GPTModel):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, the return type will be helpful

"""Forward training step.

Args:
data_iterator : Input data iterator
model (GPTModel): The GPT Model
"""

timers = state.cfg.model_config.timers
straggler_timer = state.straggler_timer

timers("batch-generator", log_level=2).start()
with straggler_timer(bdata=True):
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator, state.cfg)
timers("batch-generator").stop()

with straggler_timer:
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

return output_tensor, partial(masked_next_token_loss, loss_mask)
129 changes: 129 additions & 0 deletions nemo/tron/llm/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Iterable
import torch
from megatron.core import parallel_state
from nemo.tron.config import ConfigContainer


def get_batch_on_this_tp_rank(data_iterator: Iterable, cfg: ConfigContainer) -> Dict[str, torch.Tensor]:
def _broadcast(item):
if item is not None:
torch.distributed.broadcast(
item,
parallel_state.get_tensor_model_parallel_src_rank(),
group=parallel_state.get_tensor_model_parallel_group(),
)

if parallel_state.get_tensor_model_parallel_rank() == 0:
if data_iterator is not None:
data = next(data_iterator)
else:
data = None

batch = {
"tokens": data["tokens"].cuda(non_blocking=True),
"labels": data["labels"].cuda(non_blocking=True),
"loss_mask": data["loss_mask"].cuda(non_blocking=True),
"attention_mask": None if "attention_mask" not in data else data["attention_mask"].cuda(non_blocking=True),
"position_ids": data["position_ids"].cuda(non_blocking=True),
}

if cfg.model_config.pipeline_model_parallel_size == 1:
_broadcast(batch["tokens"])
_broadcast(batch["labels"])
_broadcast(batch["loss_mask"])
_broadcast(batch["attention_mask"])
_broadcast(batch["position_ids"])

elif parallel_state.is_pipeline_first_stage():
_broadcast(batch["tokens"])
_broadcast(batch["attention_mask"])
_broadcast(batch["position_ids"])

elif parallel_state.is_pipeline_last_stage():
_broadcast(batch["labels"])
_broadcast(batch["loss_mask"])
_broadcast(batch["attention_mask"])

else:
mbs = cfg.train_config.micro_batch_size
seq_length = cfg.model_config.seq_length
tokens = torch.empty(
(mbs, seq_length),
dtype=torch.int64,
device=torch.cuda.current_device(),
)
labels = torch.empty(
(mbs, seq_length),
dtype=torch.int64,
device=torch.cuda.current_device(),
)
loss_mask = torch.empty(
(mbs, seq_length),
dtype=torch.float32,
device=torch.cuda.current_device(),
)
if cfg.dataset_config.create_attention_mask_in_dataloader:
attention_mask = torch.empty(
(
mbs,
1,
seq_length,
seq_length,
),
dtype=torch.bool,
device=torch.cuda.current_device(),
)
else:
attention_mask = None
position_ids = torch.empty(
(mbs, seq_length),
dtype=torch.int64,
device=torch.cuda.current_device(),
)

if cfg.model_config.pipeline_model_parallel_size == 1:
_broadcast(tokens)
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
_broadcast(position_ids)

elif parallel_state.is_pipeline_first_stage():
labels = None
loss_mask = None

_broadcast(tokens)
_broadcast(attention_mask)
_broadcast(position_ids)

elif parallel_state.is_pipeline_last_stage():
tokens = None
position_ids = None

_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)

batch = {
"tokens": tokens,
"labels": labels,
"loss_mask": loss_mask,
"attention_mask": attention_mask,
"position_ids": position_ids,
}

return batch
Loading
Loading