Skip to content

Commit 58f7c4c

Browse files
ashors1, ananthsub, yuki-97
authored
feat: refactor mcore train/forward utilities (#1654)
Signed-off-by: ashors1 <ashors@nvidia.com> Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com> Co-authored-by: Ananth Subramaniam <ansubramania@nvidia.com> Co-authored-by: Yuki Huang <yukih@nvidia.com>
1 parent 8ef0de9 commit 58f7c4c

File tree

7 files changed

+2095
-551
lines changed

7 files changed

+2095
-551
lines changed

nemo_rl/models/megatron/common.py

Lines changed: 2 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -11,28 +11,17 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from functools import partial
15-
from typing import Any, Iterator, Optional
14+
15+
from typing import Any, Optional
1616

1717
import torch
1818
import torch.distributed as dist
19-
from megatron.bridge.training.state import GlobalState
20-
from megatron.core.models.gpt import GPTModel
21-
from megatron.core.parallel_state import (
22-
get_context_parallel_group,
23-
get_context_parallel_world_size,
24-
get_tensor_model_parallel_group,
25-
get_tensor_model_parallel_rank,
26-
)
2719
from megatron.core.transformer.moe.moe_utils import (
2820
clear_aux_losses_tracker,
2921
get_moe_layer_wise_logging_tracker,
3022
reduce_aux_losses_tracker_across_ranks,
3123
)
3224

33-
from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper
34-
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
35-
3625

3726
def _round_up_to_multiple(value: int, multiple: int) -> int:
3827
return (
@@ -42,119 +31,6 @@ def _round_up_to_multiple(value: int, multiple: int) -> int:
4231
)
4332

4433

45-
def forward_step_arbitrary_loss(
46-
state: GlobalState,
47-
global_valid_seqs: torch.Tensor,
48-
global_valid_toks: torch.Tensor,
49-
data_iterator: Iterator[BatchedDataDict[Any]],
50-
model: GPTModel,
51-
loss_fn: LossFunction,
52-
pack_sequences: bool = False,
53-
defer_fp32_logits: Optional[bool] = None,
54-
cp_normalize: bool = True,
55-
policy_cfg: Optional[dict] = None,
56-
):
57-
"""Forward training step with support for packed sequences and context parallelism.
58-
59-
Args:
60-
state (GlobalState): Global state for the run
61-
global_valid_seqs: Global count of valid sequences
62-
global_valid_toks: Global count of valid tokens
63-
data_iterator: Input data iterator
64-
model (GPTModel): The GPT Model
65-
loss_fn (LossFunction): Loss function to apply
66-
pack_sequences (bool): Whether to pack sequences for efficiency
67-
defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32
68-
cp_normalize (bool): Whether to normalize the loss by the cp_size
69-
policy_cfg (Optional[dict]): Policy configuration containing generation parameters
70-
71-
Notes on packed sequences with context parallelism (CP):
72-
- When CP > 1, each sequence is padded to a multiple of (cp_size * 2)
73-
- The factor of 2 ensures load balancing for causal attention
74-
- cu_seqlens tracks actual sequence boundaries
75-
- cu_seqlens_padded tracks padded sequence boundaries for CP
76-
- Requires TransformerEngine >= 1.10 for CP support
77-
"""
78-
straggler_timer = state.straggler_timer
79-
80-
# Get the pre-processed microbatch from the iterator
81-
processed_mb = next(data_iterator)
82-
83-
# Extract the processed components
84-
data_dict = processed_mb.data_dict
85-
input_ids = processed_mb.input_ids
86-
input_ids_cp_sharded = processed_mb.input_ids_cp_sharded
87-
attention_mask = processed_mb.attention_mask
88-
position_ids = processed_mb.position_ids
89-
packed_seq_params = processed_mb.packed_seq_params
90-
cu_seqlens_padded = processed_mb.cu_seqlens_padded
91-
92-
multimodal_data = data_dict.get_multimodal_dict(
93-
as_tensors=True, device=input_ids_cp_sharded.device
94-
)
95-
if len(multimodal_data) > 0:
96-
position_ids = None
97-
98-
additional_kwargs = {}
99-
# Mamba models currently do not support packed_seq_params
100-
if packed_seq_params is not None:
101-
additional_kwargs["packed_seq_params"] = packed_seq_params
102-
103-
if defer_fp32_logits:
104-
additional_kwargs["fp32_output"] = False
105-
106-
with straggler_timer:
107-
output_tensor = model(
108-
input_ids=input_ids_cp_sharded,
109-
position_ids=position_ids,
110-
attention_mask=attention_mask,
111-
**additional_kwargs,
112-
**multimodal_data,
113-
)
114-
115-
# Apply temperature scaling to logits for training
116-
# This matches the dtensor worker's _apply_temperature_scaling in the train method
117-
if (
118-
policy_cfg is not None
119-
and "generation" in policy_cfg
120-
and policy_cfg["generation"] is not None
121-
):
122-
output_tensor.div_(policy_cfg["generation"]["temperature"])
123-
124-
# Unpack the output tensor if we did packed sequences
125-
if pack_sequences and packed_seq_params is not None:
126-
# remove padding
127-
loss_fn = SequencePackingLossWrapper(
128-
loss_fn=loss_fn,
129-
cu_seqlens_q=packed_seq_params.cu_seqlens_q,
130-
cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded,
131-
)
132-
133-
loss_data = data_dict
134-
135-
loss_fn_wrapped = partial(
136-
loss_fn,
137-
data=loss_data,
138-
global_valid_seqs=global_valid_seqs,
139-
global_valid_toks=global_valid_toks,
140-
vocab_parallel_rank=get_tensor_model_parallel_rank(),
141-
vocab_parallel_group=get_tensor_model_parallel_group(),
142-
context_parallel_group=get_context_parallel_group(),
143-
)
144-
145-
if cp_normalize:
146-
cp_size = get_context_parallel_world_size()
147-
orig_loss_fn_wrapped = loss_fn_wrapped
148-
149-
def _div_by_cp_size(*args, **kwargs):
150-
loss, metrics = orig_loss_fn_wrapped(*args, **kwargs)
151-
return loss / cp_size, metrics
152-
153-
loss_fn_wrapped = _div_by_cp_size
154-
155-
return output_tensor, loss_fn_wrapped
156-
157-
15834
def broadcast_tensor(
15935
tensor: torch.Tensor | None, src_rank: int, group: dist.ProcessGroup
16036
) -> torch.Tensor:

nemo_rl/models/megatron/data.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from contextlib import nullcontext
1516
from dataclasses import dataclass
1617
from typing import Any, Iterator, Optional, Tuple
1718

@@ -211,7 +212,7 @@ def process_microbatch(
211212
pad_packed_seq_to_multiple_of: int = 1,
212213
pad_full_seq_to: Optional[int] = None,
213214
pack_sequences: bool = False,
214-
straggler_timer: StragglerDetector = None,
215+
straggler_timer: Optional[StragglerDetector] = None,
215216
) -> tuple[
216217
torch.Tensor,
217218
torch.Tensor,
@@ -221,7 +222,8 @@ def process_microbatch(
221222
Optional[torch.Tensor],
222223
]:
223224
"""Process a microbatch for Megatron model forward pass."""
224-
with straggler_timer(bdata=True):
225+
ctx = straggler_timer(bdata=True) if straggler_timer is not None else nullcontext()
226+
with ctx:
225227
input_ids = data_dict["input_ids"]
226228
attention_mask = None
227229
position_ids = None
@@ -294,15 +296,15 @@ def process_global_batch(
294296
*,
295297
batch_idx: int,
296298
batch_size: int,
297-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
299+
) -> dict[str, Any]:
298300
"""Process a global batch and compute normalization factors.
299301
300302
Args:
301-
data: Full dataset
303+
data: Full dataset to extract a batch from
304+
loss_fn: Loss function (used to check loss type for token-level validation)
305+
dp_group: Data parallel process group for all-reduce
302306
batch_idx: Index of batch to extract
303307
batch_size: Size of batch to extract
304-
loss_fn: Loss function (used to check loss type)
305-
dp_mesh: Data parallel mesh
306308
307309
Returns:
308310
Dictionary containing:
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Pipeline parallel utilities for Megatron models."""
16+
17+
from typing import Any, Optional
18+
19+
import torch
20+
from megatron.core.parallel_state import (
21+
get_pipeline_model_parallel_group,
22+
get_pipeline_model_parallel_last_rank,
23+
get_pipeline_model_parallel_world_size,
24+
is_pipeline_last_stage,
25+
)
26+
27+
28+
def broadcast_obj_from_pp_rank(obj: Any) -> Any:
    """Broadcast an object across pipeline parallel ranks.

    Exactly one pipeline parallel rank is expected to hold a non-None object;
    this function discovers that rank and broadcasts its object to every
    other rank in the pipeline parallel group.

    Args:
        obj: The object to broadcast. May be None on ranks that don't own it.

    Returns:
        The object on all ranks (either the original or the broadcast copy).

    Raises:
        ValueError: If no rank owns the object, or more than one rank does.
    """
    world_size = get_pipeline_model_parallel_world_size()
    group = get_pipeline_model_parallel_group()

    # Without pipeline parallelism there is nothing to communicate.
    if world_size == 1:
        return obj

    # Step 1: every rank reports whether it holds the object.
    presence: list = [None] * world_size
    torch.distributed.all_gather_object(presence, obj is not None, group=group)

    # Step 2: exactly one rank may own the object.
    true_ranks = [rank for rank, flag in enumerate(presence) if flag]
    if not true_ranks:
        raise ValueError("Object must exist on at least one PP rank")
    if len(true_ranks) > 1:
        raise ValueError(f"Object present on multiple PP ranks: {true_ranks}")

    # Step 3: broadcast from the owner. broadcast_object_list wants a global
    # rank, so translate the group-local index first.
    group_ranks = torch.distributed.get_process_group_ranks(group)
    global_src = group_ranks[true_ranks[0]]
    payload = [obj]
    torch.distributed.broadcast_object_list(payload, src=global_src, group=group)

    return payload[0]
77+
78+
79+
def broadcast_loss_metrics_from_last_stage(loss_metrics: Optional[list] = None) -> list:
    """Broadcast loss metrics from the last pipeline stage to all stages.

    Loss computation happens only on the last pipeline stage; this utility
    replicates the resulting metrics to every other stage via an
    object-list broadcast.

    Args:
        loss_metrics: List of loss metrics if on last stage, None otherwise

    Returns:
        List of loss metrics on all ranks
    """
    group = get_pipeline_model_parallel_group()
    src_rank = get_pipeline_model_parallel_last_rank()
    on_last_stage = is_pipeline_last_stage(ignore_virtual=True)

    # The source rank fills the slot with its metrics; receivers pass None
    # and get the slot overwritten by the broadcast.
    slot = [loss_metrics if on_last_stage else None]
    torch.distributed.broadcast_object_list(slot, src=src_rank, group=group)

    return loss_metrics if on_last_stage else slot[0]
110+
111+
112+
def broadcast_tensors_from_last_stage(
    tensors: dict[str, Optional[torch.Tensor]],
) -> dict[str, torch.Tensor]:
    """Broadcast multiple tensors from the last pipeline stage to all stages.

    Args:
        tensors: Dictionary mapping tensor names to tensors. On the last
            pipeline stage every value must be a tensor; on other stages the
            values are ignored (typically None) and only the number of keys
            matters — one tensor is received per key, in iteration order.

    Returns:
        Dictionary of broadcast tensors on all ranks.

    Raises:
        ValueError: If the last pipeline stage provides None for any name.

    NOTE(review): sender and receivers must pass dicts with the same keys in
    the same insertion order — tensors are matched purely by broadcast order,
    not by name. Confirm callers uphold this.
    """
    pp_group = get_pipeline_model_parallel_group()

    # Local import — presumably to avoid a circular import with
    # nemo_rl.models.megatron.common; TODO confirm before hoisting.
    from nemo_rl.models.megatron.common import broadcast_tensor

    last_rank = get_pipeline_model_parallel_last_rank()
    current_rank = torch.distributed.get_rank()

    broadcasted_tensors: dict[str, torch.Tensor] = {}

    if is_pipeline_last_stage(ignore_virtual=True):
        # Source side: broadcast each tensor from this (last-stage) rank.
        for name, tensor in tensors.items():
            if tensor is None:
                raise ValueError(
                    f"Last PP stage must provide tensor '{name}' for broadcast."
                )
            broadcasted_tensors[name] = broadcast_tensor(tensor, current_rank, pp_group)
    else:
        # Receiver side: broadcast_tensor allocates and fills one tensor per key.
        for name in tensors.keys():
            broadcasted_tensors[name] = broadcast_tensor(None, last_rank, pp_group)

    return broadcasted_tensors

0 commit comments

Comments
 (0)