Skip to content

Commit 9c7cab8

Browse files
committed
Merge branch 'ashors/mcore-data' of github.com:NVIDIA-NeMo/RL into ashors/mcore-train
2 parents e04d2da + a2a8a51 commit 9c7cab8

File tree

5 files changed

+168
-30
lines changed

5 files changed

+168
-30
lines changed

nemo_rl/models/megatron/common.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,131 @@
2323
reduce_aux_losses_tracker_across_ranks,
2424
)
2525

26+
from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper
27+
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
28+
29+
2630
def _round_up_to_multiple(value: int, multiple: int) -> int:
2731
return (
2832
((value + multiple - 1) // multiple * multiple)
2933
if value % multiple != 0
3034
else value
3135
)
3236

37+
38+
def forward_step_arbitrary_loss(
    state: GlobalState,
    global_valid_seqs: torch.Tensor,
    global_valid_toks: torch.Tensor,
    data_iterator: Iterator[BatchedDataDict[Any]],
    model: GPTModel,
    loss_fn: LossFunction,
    pack_sequences: bool = False,
    defer_fp32_logits: Optional[bool] = None,
    cp_normalize: bool = True,
    policy_cfg: Optional[dict] = None,
):
    """Forward training step with support for packed sequences and context parallelism.

    Args:
        state (GlobalState): Global state for the run
        global_valid_seqs: Global count of valid sequences
        global_valid_toks: Global count of valid tokens
        data_iterator: Input data iterator yielding pre-processed microbatches
        model (GPTModel): The GPT Model
        loss_fn (LossFunction): Loss function to apply
        pack_sequences (bool): Whether to pack sequences for efficiency
        defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32
        cp_normalize (bool): Whether to normalize the loss by the cp_size
        policy_cfg (Optional[dict]): Policy configuration containing generation parameters

    Returns:
        Tuple of (output_tensor, loss_fn_wrapped): the model's output logits and a
        closure that computes the loss given that output.

    Notes on packed sequences with context parallelism (CP):
    - When CP > 1, each sequence is padded to a multiple of (cp_size * 2)
    - The factor of 2 ensures load balancing for causal attention
    - cu_seqlens tracks actual sequence boundaries
    - cu_seqlens_padded tracks padded sequence boundaries for CP
    - Requires TransformerEngine >= 1.10 for CP support
    """
    straggler_timer = state.straggler_timer

    # Get the pre-processed microbatch from the iterator
    processed_mb = next(data_iterator)

    # Extract the processed components (see ProcessedMicrobatch for field meanings)
    data_dict = processed_mb.data_dict
    input_ids = processed_mb.input_ids
    input_ids_cp_sharded = processed_mb.input_ids_cp_sharded
    attention_mask = processed_mb.attention_mask
    position_ids = processed_mb.position_ids
    packed_seq_params = processed_mb.packed_seq_params
    cu_seqlens_padded = processed_mb.cu_seqlens_padded

    # Move any multimodal inputs onto the same device as the (possibly
    # CP-sharded) token ids so they can be fed to the model together.
    multimodal_data = data_dict.get_multimodal_dict(
        as_tensors=True, device=input_ids_cp_sharded.device
    )
    if len(multimodal_data) > 0:
        # NOTE(review): position_ids are dropped for multimodal inputs —
        # presumably the model derives positions internally; confirm upstream.
        position_ids = None

    additional_kwargs = {}
    # Mamba models currently do not support packed_seq_params
    if packed_seq_params is not None:
        additional_kwargs["packed_seq_params"] = packed_seq_params

    # Keep logits in their native dtype; conversion to fp32 is deferred to the
    # loss computation when defer_fp32_logits is truthy.
    if defer_fp32_logits:
        additional_kwargs["fp32_output"] = False

    # Time the forward pass for straggler detection.
    with straggler_timer:
        output_tensor = model(
            input_ids=input_ids_cp_sharded,
            position_ids=position_ids,
            attention_mask=attention_mask,
            **additional_kwargs,
            **multimodal_data,
        )

    # Apply temperature scaling to logits for training
    # This matches the dtensor worker's _apply_temperature_scaling in the train method
    if (
        policy_cfg is not None
        and "generation" in policy_cfg
        and policy_cfg["generation"] is not None
    ):
        # In-place divide to avoid allocating a second logits tensor.
        output_tensor.div_(policy_cfg["generation"]["temperature"])

    # Unpack the output tensor if we did packed sequences
    if pack_sequences and packed_seq_params is not None:
        # remove padding: the wrapper slices per-sequence spans using the
        # (actual vs. padded) cumulative sequence lengths before applying loss_fn
        loss_fn = SequencePackingLossWrapper(
            loss_fn=loss_fn,
            cu_seqlens_q=packed_seq_params.cu_seqlens_q,
            cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded,
        )

    loss_data = data_dict

    # Bind all loss inputs except the model output, which Megatron's
    # forward-backward machinery supplies when it invokes this closure.
    loss_fn_wrapped = partial(
        loss_fn,
        data=loss_data,
        global_valid_seqs=global_valid_seqs,
        global_valid_toks=global_valid_toks,
        vocab_parallel_rank=get_tensor_model_parallel_rank(),
        vocab_parallel_group=get_tensor_model_parallel_group(),
        context_parallel_group=get_context_parallel_group(),
    )

    if cp_normalize:
        # Each CP rank computes loss on its shard; dividing by cp_size keeps the
        # summed/reduced loss consistent with the non-CP case.
        cp_size = get_context_parallel_world_size()
        orig_loss_fn_wrapped = loss_fn_wrapped

        def _div_by_cp_size(*args, **kwargs):
            loss, metrics = orig_loss_fn_wrapped(*args, **kwargs)
            return loss / cp_size, metrics

        loss_fn_wrapped = _div_by_cp_size

    return output_tensor, loss_fn_wrapped
149+
150+
33151
def broadcast_tensor(
34152
tensor: torch.Tensor | None, src_rank: int, group: dist.ProcessGroup
35153
) -> torch.Tensor:

nemo_rl/models/megatron/data.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@
1616
from typing import Any, Iterator, Optional, Tuple
1717

1818
import torch
19-
2019
from megatron.core.packed_seq_params import PackedSeqParams
2120
from megatron.core.parallel_state import (
2221
get_context_parallel_rank,
2322
get_context_parallel_world_size,
2423
)
2524
from megatron.training.utils import get_ltor_masks_and_position_ids
26-
from nemo_rl.models.megatron.common import _round_up_to_multiple
25+
2726
from nemo_rl.algorithms.interfaces import LossFunction, LossType
2827
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
2928
from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank
29+
from nemo_rl.models.megatron.common import _round_up_to_multiple
3030

3131

3232
@dataclass
@@ -45,6 +45,7 @@ class ProcessedMicrobatch:
4545
packed_seq_params: PackedSeqParams for sequence packing (None if not packing)
4646
cu_seqlens_padded: Padded cumulative sequence lengths (None if not packing)
4747
"""
48+
4849
data_dict: BatchedDataDict[Any]
4950
input_ids: torch.Tensor
5051
input_ids_cp_sharded: torch.Tensor
@@ -192,6 +193,7 @@ def get_microbatch_iterator(
192193
padded_seq_length,
193194
)
194195

196+
195197
def process_microbatch(
196198
data_dict: BatchedDataDict[Any],
197199
seq_length_key: Optional[str] = None,
@@ -200,7 +202,7 @@ def process_microbatch(
200202
pad_full_seq_to: Optional[int] = None,
201203
pack_sequences: bool = False,
202204
):
203-
#with straggler_timer(bdata=True):
205+
# with straggler_timer(bdata=True):
204206
input_ids = data_dict["input_ids"]
205207
attention_mask = None
206208
position_ids = None
@@ -217,9 +219,7 @@ def process_microbatch(
217219
assert seq_length_key is not None, (
218220
"seq_length_key must be provided for packed sequences"
219221
)
220-
assert seq_length_key in data_dict, (
221-
f"{seq_length_key} not found in data_dict"
222-
)
222+
assert seq_length_key in data_dict, f"{seq_length_key} not found in data_dict"
223223

224224
# Get sequence lengths and context parallel size
225225
seq_lengths = data_dict[seq_length_key]
@@ -240,7 +240,7 @@ def process_microbatch(
240240
cp_rank=get_context_parallel_rank(),
241241
cp_size=get_context_parallel_world_size(),
242242
)
243-
243+
244244
# For packed sequences, position_ids and attention_mask are typically None
245245
# The PackedSeqParams handles all necessary sequence information
246246
position_ids = None
@@ -265,6 +265,7 @@ def process_microbatch(
265265
cu_seqlens_padded,
266266
)
267267

268+
268269
def process_global_batch(
269270
data: BatchedDataDict[Any],
270271
batch_idx: int,
@@ -301,6 +302,7 @@ def process_global_batch(
301302
global_valid_toks,
302303
)
303304

305+
304306
def _pack_sequences_for_megatron(
305307
input_ids: torch.Tensor,
306308
seq_lengths: torch.Tensor,
@@ -605,6 +607,7 @@ def _unpack_sequences_from_megatron(
605607

606608
return unpacked_output
607609

610+
608611
def check_sequence_dim(data: BatchedDataDict[Any]):
609612
# dim 1 is always assumed to be the sequence dim, sanity check this here
610613
sequence_dim = 1
@@ -614,4 +617,4 @@ def check_sequence_dim(data: BatchedDataDict[Any]):
614617
assert v.shape[sequence_dim] == seq_dim_size, (
615618
f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}"
616619
)
617-
return sequence_dim, seq_dim_size
620+
return sequence_dim, seq_dim_size

nemo_rl/models/policy/workers/megatron_policy_worker.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -94,23 +94,16 @@
9494
verify_right_padding,
9595
)
9696
from nemo_rl.models.generation.vllm.config import VllmConfig
97-
from nemo_rl.models.megatron.common import get_moe_metrics
97+
from nemo_rl.models.megatron.common import (
98+
broadcast_tensor,
99+
forward_step_arbitrary_loss,
100+
get_moe_metrics,
101+
)
102+
from nemo_rl.models.megatron.community_import import import_model_from_hf_name
98103
from nemo_rl.models.megatron.data import (
99104
get_microbatch_iterator,
100105
process_global_batch,
101106
)
102-
from nemo_rl.models.megatron.pipeline_parallel import (
103-
broadcast_obj_from_pp_rank,
104-
broadcast_loss_metrics_from_last_stage,
105-
broadcast_tensors_from_last_stage,
106-
)
107-
from nemo_rl.models.megatron.train import (
108-
megatron_forward_backward,
109-
LossPostProcessor,
110-
LogprobsPostProcessor,
111-
TopkLogitsPostProcessor,
112-
)
113-
from nemo_rl.models.megatron.community_import import import_model_from_hf_name
114107
from nemo_rl.models.policy import PolicyConfig
115108
from nemo_rl.models.policy.interfaces import (
116109
ColocatablePolicyInterface,

tests/unit/algorithms/test_sequence_packing_gradients.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,16 @@ def forward(
328328
output_tensor, wrapped_loss_fn = forward_with_post_processing_fn(
329329
data_iterator=make_processed_microbatch_iterator(
330330
iter([packed_data_dict]),
331-
cfg=cfg,
331+
cfg={
332+
"sequence_packing": {"enabled": True},
333+
"dynamic_batching": {"enabled": False},
334+
"megatron_cfg": {
335+
"tensor_model_parallel_size": 1,
336+
"sequence_parallel": False,
337+
"pipeline_model_parallel_size": 1,
338+
"context_parallel_size": cp_size,
339+
},
340+
},
332341
seq_length_key="input_lengths",
333342
pad_individual_seqs_to_multiple_of=pad_to_multiple,
334343
pad_packed_seq_to_multiple_of=1,

tests/unit/models/megatron/test_megatron_data.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def test_processed_microbatch_fields(self):
7373
assert microbatch.packed_seq_params == mock_packed_seq_params
7474
assert torch.equal(microbatch.cu_seqlens_padded, mock_cu_seqlens_padded)
7575

76+
7677
class TestCheckSequenceDim:
7778
"""Tests for check_sequence_dim function."""
7879

@@ -154,7 +155,9 @@ def test_process_microbatch_no_packing(self, mock_get_masks):
154155

155156
# Create test data
156157
data_dict = MagicMock()
157-
input_ids = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0, 0, 0], [6, 7, 8, 9, 10, 11, 12, 0, 0, 0]])
158+
input_ids = torch.tensor(
159+
[[1, 2, 3, 4, 5, 0, 0, 0, 0, 0], [6, 7, 8, 9, 10, 11, 12, 0, 0, 0]]
160+
)
158161
data_dict.__getitem__ = MagicMock(return_value=input_ids)
159162

160163
(
@@ -178,7 +181,9 @@ def test_process_microbatch_no_packing(self, mock_get_masks):
178181
mock_get_masks.assert_called_once()
179182

180183
@patch("nemo_rl.models.megatron.data.get_context_parallel_rank", return_value=0)
181-
@patch("nemo_rl.models.megatron.data.get_context_parallel_world_size", return_value=1)
184+
@patch(
185+
"nemo_rl.models.megatron.data.get_context_parallel_world_size", return_value=1
186+
)
182187
@patch("nemo_rl.models.megatron.data._pack_sequences_for_megatron")
183188
def test_process_microbatch_with_packing(
184189
self, mock_pack, mock_cp_world, mock_cp_rank
@@ -226,7 +231,7 @@ def test_process_microbatch_with_packing(
226231
assert attention_mask is None
227232
assert position_ids is None
228233
assert cu_seqlens_padded is not None
229-
234+
230235
# Verify pack was called
231236
mock_pack.assert_called_once()
232237

@@ -323,6 +328,7 @@ def test_process_global_batch_requires_sample_mask(self):
323328

324329
assert "sample_mask must be present" in str(exc_info.value)
325330

331+
326332
class TestGetMicrobatchIterator:
327333
"""Tests for get_microbatch_iterator function."""
328334

@@ -383,8 +389,13 @@ def test_get_microbatch_iterator_sequence_packing(
383389
mock_make_iterator.return_value = mock_iterator
384390

385391
mock_data = MagicMock()
386-
mock_data.make_microbatch_iterator_for_packable_sequences.return_value = iter([])
387-
mock_data.get_microbatch_iterator_for_packable_sequences_len.return_value = (10, 512)
392+
mock_data.make_microbatch_iterator_for_packable_sequences.return_value = iter(
393+
[]
394+
)
395+
mock_data.get_microbatch_iterator_for_packable_sequences_len.return_value = (
396+
10,
397+
512,
398+
)
388399

389400
cfg = {
390401
"dynamic_batching": {"enabled": False},
@@ -473,8 +484,13 @@ def test_get_microbatch_iterator_auto_detects_seq_length_key(
473484
mock_make_iterator.return_value = mock_iterator
474485

475486
mock_data = MagicMock()
476-
mock_data.make_microbatch_iterator_for_packable_sequences.return_value = iter([])
477-
mock_data.get_microbatch_iterator_for_packable_sequences_len.return_value = (5, 256)
487+
mock_data.make_microbatch_iterator_for_packable_sequences.return_value = iter(
488+
[]
489+
)
490+
mock_data.get_microbatch_iterator_for_packable_sequences_len.return_value = (
491+
5,
492+
256,
493+
)
478494

479495
cfg = {
480496
"dynamic_batching": {"enabled": False},
@@ -1677,4 +1693,3 @@ def test_get_pack_sequence_parameters_for_megatron(get_pack_sequence_parameters_
16771693
# Check that all workers succeeded
16781694
for i, result in enumerate(results):
16791695
assert result["success"], f"Worker {i} failed: {result['error']}"
1680-

0 commit comments

Comments
 (0)