Skip to content

Commit fc94f8b

Browse files
committed
lint and address comments
Signed-off-by: ashors1 <ashors@nvidia.com>
1 parent ea02f64 commit fc94f8b

File tree

5 files changed

+126
-76
lines changed

5 files changed

+126
-76
lines changed

nemo_rl/models/megatron/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
get_moe_layer_wise_logging_tracker,
3030
reduce_aux_losses_tracker_across_ranks,
3131
)
32+
3233
from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper
3334
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
3435

@@ -40,6 +41,7 @@ def _round_up_to_multiple(value: int, multiple: int) -> int:
4041
else value
4142
)
4243

44+
4345
def forward_step_arbitrary_loss(
4446
state: GlobalState,
4547
global_valid_seqs: torch.Tensor,

nemo_rl/models/megatron/data.py

Lines changed: 80 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,18 @@
1616
from typing import Any, Iterator, Optional, Tuple
1717

1818
import torch
19-
2019
from megatron.core.packed_seq_params import PackedSeqParams
2120
from megatron.core.parallel_state import (
2221
get_context_parallel_rank,
2322
get_context_parallel_world_size,
2423
)
24+
from megatron.core.utils import StragglerDetector
2525
from megatron.training.utils import get_ltor_masks_and_position_ids
26-
from nemo_rl.models.megatron.common import _round_up_to_multiple
26+
2727
from nemo_rl.algorithms.interfaces import LossFunction, LossType
2828
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
2929
from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank
30+
from nemo_rl.models.megatron.common import _round_up_to_multiple
3031

3132

3233
@dataclass
@@ -45,6 +46,7 @@ class ProcessedMicrobatch:
4546
packed_seq_params: PackedSeqParams for sequence packing (None if not packing)
4647
cu_seqlens_padded: Padded cumulative sequence lengths (None if not packing)
4748
"""
49+
4850
data_dict: BatchedDataDict[Any]
4951
input_ids: torch.Tensor
5052
input_ids_cp_sharded: torch.Tensor
@@ -60,6 +62,7 @@ def make_processed_microbatch_iterator(
6062
seq_length_key: Optional[str],
6163
pad_individual_seqs_to_multiple_of: int,
6264
pad_packed_seq_to_multiple_of: int,
65+
straggler_timer: StragglerDetector,
6366
pad_full_seq_to: Optional[int],
6467
) -> Iterator[ProcessedMicrobatch]:
6568
"""Wrap a raw microbatch iterator to yield processed microbatches.
@@ -100,6 +103,7 @@ def make_processed_microbatch_iterator(
100103
pad_packed_seq_to_multiple_of=pad_packed_seq_to_multiple_of,
101104
pad_full_seq_to=pad_full_seq_to,
102105
pack_sequences=pack_sequences,
106+
straggler_timer=straggler_timer,
103107
)
104108

105109
yield ProcessedMicrobatch(
@@ -117,6 +121,7 @@ def get_microbatch_iterator(
117121
data: BatchedDataDict[Any],
118122
cfg: dict[str, Any],
119123
mbs: int,
124+
straggler_timer: StragglerDetector,
120125
seq_length_key: Optional[str] = None,
121126
) -> Tuple[Iterator[ProcessedMicrobatch], int, int, int, int]:
122127
"""Create a processed microbatch iterator from a batch of data.
@@ -179,6 +184,7 @@ def get_microbatch_iterator(
179184
pad_individual_seqs_to_multiple_of=pad_factor,
180185
pad_packed_seq_to_multiple_of=pad_packed_seq_to_multiple_of,
181186
pad_full_seq_to=pad_full_seq_to,
187+
straggler_timer=straggler_timer,
182188
)
183189

184190
# Compute padded sequence length for pipeline parallelism
@@ -192,70 +198,80 @@ def get_microbatch_iterator(
192198
padded_seq_length,
193199
)
194200

201+
195202
def process_microbatch(
196203
data_dict: BatchedDataDict[Any],
197204
seq_length_key: Optional[str] = None,
198205
pad_individual_seqs_to_multiple_of: int = 1,
199206
pad_packed_seq_to_multiple_of: int = 1,
200207
pad_full_seq_to: Optional[int] = None,
201208
pack_sequences: bool = False,
202-
):
203-
#with straggler_timer(bdata=True):
204-
input_ids = data_dict["input_ids"]
205-
attention_mask = None
206-
position_ids = None
207-
packed_seq_params = None
208-
209-
original_batch_size = input_ids.shape[0]
210-
original_seq_length = input_ids.shape[1]
211-
seq_lengths = None # Will be set if using packed sequences
212-
cu_seqlens = None
213-
cu_seqlens_padded = None
214-
215-
if pack_sequences:
216-
# For packed sequences with padded input, we need sequence lengths
217-
assert seq_length_key is not None, (
218-
"seq_length_key must be provided for packed sequences"
219-
)
220-
assert seq_length_key in data_dict, (
221-
f"{seq_length_key} not found in data_dict"
222-
)
209+
straggler_timer: StragglerDetector = None,
210+
) -> tuple[
211+
torch.Tensor,
212+
torch.Tensor,
213+
Optional[torch.Tensor],
214+
Optional[torch.Tensor],
215+
Optional[PackedSeqParams],
216+
Optional[torch.Tensor],
217+
]:
218+
"""Process a microbatch for Megatron model forward pass."""
219+
with straggler_timer(bdata=True):
220+
input_ids = data_dict["input_ids"]
221+
attention_mask = None
222+
position_ids = None
223+
packed_seq_params = None
224+
225+
original_batch_size = input_ids.shape[0]
226+
original_seq_length = input_ids.shape[1]
227+
seq_lengths = None # Will be set if using packed sequences
228+
cu_seqlens = None
229+
cu_seqlens_padded = None
230+
231+
if pack_sequences:
232+
# For packed sequences with padded input, we need sequence lengths
233+
assert seq_length_key is not None, (
234+
"seq_length_key must be provided for packed sequences"
235+
)
236+
assert seq_length_key in data_dict, (
237+
f"{seq_length_key} not found in data_dict"
238+
)
223239

224-
# Get sequence lengths and context parallel size
225-
seq_lengths = data_dict[seq_length_key]
240+
# Get sequence lengths and context parallel size
241+
seq_lengths = data_dict[seq_length_key]
242+
243+
# Pack sequences
244+
(
245+
input_ids,
246+
input_ids_cp_sharded,
247+
packed_seq_params,
248+
cu_seqlens,
249+
cu_seqlens_padded,
250+
) = _pack_sequences_for_megatron(
251+
input_ids,
252+
seq_lengths,
253+
pad_individual_seqs_to_multiple_of,
254+
pad_packed_seq_to_multiple_of,
255+
pad_full_seq_to,
256+
cp_rank=get_context_parallel_rank(),
257+
cp_size=get_context_parallel_world_size(),
258+
)
226259

227-
# Pack sequences
228-
(
229-
input_ids,
230-
input_ids_cp_sharded,
231-
packed_seq_params,
232-
cu_seqlens,
233-
cu_seqlens_padded,
234-
) = _pack_sequences_for_megatron(
235-
input_ids,
236-
seq_lengths,
237-
pad_individual_seqs_to_multiple_of,
238-
pad_packed_seq_to_multiple_of,
239-
pad_full_seq_to,
240-
cp_rank=get_context_parallel_rank(),
241-
cp_size=get_context_parallel_world_size(),
242-
)
243-
244-
# For packed sequences, position_ids and attention_mask are typically None
245-
# The PackedSeqParams handles all necessary sequence information
246-
position_ids = None
247-
attention_mask = None
248-
else:
249-
input_ids_cp_sharded = input_ids
250-
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
251-
data=input_ids,
252-
eod_token=0, # used for loss_mask, which we don't use
253-
pad_token=0, # used for loss_mask, which we don't use
254-
reset_position_ids=False,
255-
reset_attention_mask=False,
256-
eod_mask_loss=False,
257-
pad_mask_loss=False,
258-
)
260+
# For packed sequences, position_ids and attention_mask are typically None
261+
# The PackedSeqParams handles all necessary sequence information
262+
position_ids = None
263+
attention_mask = None
264+
else:
265+
input_ids_cp_sharded = input_ids
266+
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
267+
data=input_ids,
268+
eod_token=0, # used for loss_mask, which we don't use
269+
pad_token=0, # used for loss_mask, which we don't use
270+
reset_position_ids=False,
271+
reset_attention_mask=False,
272+
eod_mask_loss=False,
273+
pad_mask_loss=False,
274+
)
259275
return (
260276
input_ids,
261277
input_ids_cp_sharded,
@@ -265,13 +281,15 @@ def process_microbatch(
265281
cu_seqlens_padded,
266282
)
267283

284+
268285
def process_global_batch(
269286
data: BatchedDataDict[Any],
270287
batch_idx: int,
271288
batch_size: int,
272289
loss_fn: LossFunction,
273290
dp_group: torch.distributed.ProcessGroup,
274-
) -> dict[str, Any]:
291+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
292+
"""Process a global batch for Megatron model forward pass."""
275293
batch = data.get_batch(batch_idx=batch_idx, batch_size=batch_size)
276294

277295
assert "sample_mask" in batch, "sample_mask must be present in the data!"
@@ -301,6 +319,7 @@ def process_global_batch(
301319
global_valid_toks,
302320
)
303321

322+
304323
def _pack_sequences_for_megatron(
305324
input_ids: torch.Tensor,
306325
seq_lengths: torch.Tensor,
@@ -605,13 +624,14 @@ def _unpack_sequences_from_megatron(
605624

606625
return unpacked_output
607626

627+
608628
def check_sequence_dim(data: BatchedDataDict[Any]):
609629
# dim 1 is always assumed to be the sequence dim, sanity check this here
610630
sequence_dim = 1
611631
seq_dim_size = data["input_ids"].shape[sequence_dim]
612632
for k, v in data.items():
613633
if torch.is_tensor(v) and len(v.shape) > 1:
614634
assert v.shape[sequence_dim] == seq_dim_size, (
615-
f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}"
635+
f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape} for key {k}"
616636
)
617-
return sequence_dim, seq_dim_size
637+
return sequence_dim, seq_dim_size

nemo_rl/models/policy/workers/megatron_policy_worker.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,11 @@
113113
forward_step_arbitrary_loss,
114114
get_moe_metrics,
115115
)
116+
from nemo_rl.models.megatron.community_import import import_model_from_hf_name
116117
from nemo_rl.models.megatron.data import (
117118
get_microbatch_iterator,
118119
process_global_batch,
119120
)
120-
from nemo_rl.models.megatron.community_import import import_model_from_hf_name
121121
from nemo_rl.models.policy import PolicyConfig
122122
from nemo_rl.models.policy.interfaces import (
123123
ColocatablePolicyInterface,
@@ -1000,7 +1000,12 @@ def train(
10001000
micro_batch_size,
10011001
seq_length,
10021002
padded_seq_length,
1003-
) = get_microbatch_iterator(batch, self.cfg, mbs)
1003+
) = get_microbatch_iterator(
1004+
batch,
1005+
self.cfg,
1006+
mbs,
1007+
straggler_timer=self.mcore_state.straggler_timer,
1008+
)
10041009
# Track total microbatches for MoE aux-loss averaging
10051010
total_num_microbatches += int(num_microbatches)
10061011

@@ -1024,9 +1029,9 @@ def train(
10241029
data_iterator=data_iterator,
10251030
model=self.model,
10261031
num_microbatches=num_microbatches,
1027-
seq_length=seq_dim_size,
1032+
seq_length=padded_seq_length,
10281033
micro_batch_size=mbs,
1029-
decoder_seq_length=seq_dim_size,
1034+
decoder_seq_length=padded_seq_length,
10301035
forward_only=eval_mode,
10311036
do_not_average_loss=True,
10321037
)
@@ -1176,7 +1181,12 @@ def get_logprobs(
11761181
micro_batch_size,
11771182
seq_length,
11781183
padded_seq_length,
1179-
) = get_microbatch_iterator(data, self.cfg, logprob_batch_size)
1184+
) = get_microbatch_iterator(
1185+
data,
1186+
self.cfg,
1187+
logprob_batch_size,
1188+
straggler_timer=self.mcore_state.straggler_timer,
1189+
)
11801190

11811191
def forward_step_fn(
11821192
data_iterator: Iterator[BatchedDataDict[Any]], model: GPTModel
@@ -1378,7 +1388,12 @@ def get_topk_logits(
13781388
micro_batch_size,
13791389
seq_length,
13801390
padded_seq_length,
1381-
) = get_microbatch_iterator(data, self.cfg, logprob_batch_size)
1391+
) = get_microbatch_iterator(
1392+
data,
1393+
self.cfg,
1394+
logprob_batch_size,
1395+
straggler_timer=self.mcore_state.straggler_timer,
1396+
)
13821397

13831398
def forward_step_fn(
13841399
data_iterator: Iterator[BatchedDataDict[Any]], model: GPTModel

tests/unit/algorithms/test_sequence_packing_gradients.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
339339
"pipeline_model_parallel_size": 1,
340340
"context_parallel_size": cp_size,
341341
},
342-
343-
344342
},
345343
seq_length_key="input_lengths",
346344
pad_individual_seqs_to_multiple_of=pad_to_multiple,

0 commit comments

Comments (0)