1616from typing import Any , Iterator , Optional , Tuple
1717
1818import torch
19-
2019from megatron .core .packed_seq_params import PackedSeqParams
2120from megatron .core .parallel_state import (
2221 get_context_parallel_rank ,
2322 get_context_parallel_world_size ,
2423)
24+ from megatron .core .utils import StragglerDetector
2525from megatron .training .utils import get_ltor_masks_and_position_ids
26- from nemo_rl . models . megatron . common import _round_up_to_multiple
26+
2727from nemo_rl .algorithms .interfaces import LossFunction , LossType
2828from nemo_rl .distributed .batched_data_dict import BatchedDataDict
2929from nemo_rl .distributed .model_utils import _get_tokens_on_this_cp_rank
30+ from nemo_rl .models .megatron .common import _round_up_to_multiple
3031
3132
3233@dataclass
@@ -45,6 +46,7 @@ class ProcessedMicrobatch:
4546 packed_seq_params: PackedSeqParams for sequence packing (None if not packing)
4647 cu_seqlens_padded: Padded cumulative sequence lengths (None if not packing)
4748 """
49+
4850 data_dict : BatchedDataDict [Any ]
4951 input_ids : torch .Tensor
5052 input_ids_cp_sharded : torch .Tensor
@@ -60,6 +62,7 @@ def make_processed_microbatch_iterator(
6062 seq_length_key : Optional [str ],
6163 pad_individual_seqs_to_multiple_of : int ,
6264 pad_packed_seq_to_multiple_of : int ,
65+ straggler_timer : StragglerDetector ,
6366 pad_full_seq_to : Optional [int ],
6467) -> Iterator [ProcessedMicrobatch ]:
6568 """Wrap a raw microbatch iterator to yield processed microbatches.
@@ -100,6 +103,7 @@ def make_processed_microbatch_iterator(
100103 pad_packed_seq_to_multiple_of = pad_packed_seq_to_multiple_of ,
101104 pad_full_seq_to = pad_full_seq_to ,
102105 pack_sequences = pack_sequences ,
106+ straggler_timer = straggler_timer ,
103107 )
104108
105109 yield ProcessedMicrobatch (
@@ -117,6 +121,7 @@ def get_microbatch_iterator(
117121 data : BatchedDataDict [Any ],
118122 cfg : dict [str , Any ],
119123 mbs : int ,
124+ straggler_timer : StragglerDetector ,
120125 seq_length_key : Optional [str ] = None ,
121126) -> Tuple [Iterator [ProcessedMicrobatch ], int , int , int , int ]:
122127 """Create a processed microbatch iterator from a batch of data.
@@ -179,6 +184,7 @@ def get_microbatch_iterator(
179184 pad_individual_seqs_to_multiple_of = pad_factor ,
180185 pad_packed_seq_to_multiple_of = pad_packed_seq_to_multiple_of ,
181186 pad_full_seq_to = pad_full_seq_to ,
187+ straggler_timer = straggler_timer ,
182188 )
183189
184190 # Compute padded sequence length for pipeline parallelism
@@ -192,70 +198,80 @@ def get_microbatch_iterator(
192198 padded_seq_length ,
193199 )
194200
201+
def process_microbatch(
    data_dict: BatchedDataDict[Any],
    seq_length_key: Optional[str] = None,
    pad_individual_seqs_to_multiple_of: int = 1,
    pad_packed_seq_to_multiple_of: int = 1,
    pad_full_seq_to: Optional[int] = None,
    pack_sequences: bool = False,
    straggler_timer: Optional[StragglerDetector] = None,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    Optional[torch.Tensor],
    Optional[torch.Tensor],
    Optional[PackedSeqParams],
    Optional[torch.Tensor],
]:
    """Process a microbatch for a Megatron model forward pass.

    When ``pack_sequences`` is True, the individual (padded) sequences in
    ``data_dict["input_ids"]`` are packed into a single sequence via
    ``_pack_sequences_for_megatron`` and all sequence bookkeeping is carried
    by the returned ``PackedSeqParams``. Otherwise the standard Megatron
    left-to-right attention mask and position ids are built.

    Args:
        data_dict: Batch data; must contain ``"input_ids"`` and, when
            ``pack_sequences`` is True, an entry under ``seq_length_key``.
        seq_length_key: Key holding per-sequence lengths; required when
            packing.
        pad_individual_seqs_to_multiple_of: Pad each individual sequence
            length up to a multiple of this value before packing.
        pad_packed_seq_to_multiple_of: Pad the final packed sequence length
            up to a multiple of this value.
        pad_full_seq_to: If set, pad the full packed sequence to exactly this
            length (used for pipeline parallelism).
        pack_sequences: Whether to pack sequences for Megatron.
        straggler_timer: Optional StragglerDetector context manager used to
            time data preprocessing. BUGFIX: the previous signature defaulted
            to ``None`` (with a non-Optional annotation) while the body
            unconditionally called ``straggler_timer(bdata=True)``, crashing
            with ``TypeError`` whenever the default was used. ``None`` now
            simply skips timing.

    Returns:
        Tuple of ``(input_ids, input_ids_cp_sharded, attention_mask,
        position_ids, packed_seq_params, cu_seqlens_padded)``. The last four
        entries are ``None`` when not applicable (mask/position ids are None
        when packing; packing artifacts are None when not packing).
    """
    # Local import keeps this fix self-contained; nullcontext is a no-op
    # stand-in when no straggler timer was supplied.
    from contextlib import nullcontext

    timing_ctx = (
        straggler_timer(bdata=True) if straggler_timer is not None else nullcontext()
    )
    with timing_ctx:
        input_ids = data_dict["input_ids"]
        attention_mask = None
        position_ids = None
        packed_seq_params = None

        seq_lengths = None  # Will be set if using packed sequences
        cu_seqlens = None
        cu_seqlens_padded = None

        if pack_sequences:
            # For packed sequences with padded input, we need sequence lengths
            assert seq_length_key is not None, (
                "seq_length_key must be provided for packed sequences"
            )
            assert seq_length_key in data_dict, (
                f"{seq_length_key} not found in data_dict"
            )

            # Get sequence lengths and context parallel size
            seq_lengths = data_dict[seq_length_key]

            # Pack sequences
            (
                input_ids,
                input_ids_cp_sharded,
                packed_seq_params,
                cu_seqlens,
                cu_seqlens_padded,
            ) = _pack_sequences_for_megatron(
                input_ids,
                seq_lengths,
                pad_individual_seqs_to_multiple_of,
                pad_packed_seq_to_multiple_of,
                pad_full_seq_to,
                cp_rank=get_context_parallel_rank(),
                cp_size=get_context_parallel_world_size(),
            )

            # For packed sequences, position_ids and attention_mask are typically None
            # The PackedSeqParams handles all necessary sequence information
            position_ids = None
            attention_mask = None
        else:
            input_ids_cp_sharded = input_ids
            attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
                data=input_ids,
                eod_token=0,  # used for loss_mask, which we don't use
                pad_token=0,  # used for loss_mask, which we don't use
                reset_position_ids=False,
                reset_attention_mask=False,
                eod_mask_loss=False,
                pad_mask_loss=False,
            )
    # NOTE(review): middle of the return tuple is reconstructed from the
    # 6-element return annotation — confirm against the repository.
    return (
        input_ids,
        input_ids_cp_sharded,
        attention_mask,
        position_ids,
        packed_seq_params,
        cu_seqlens_padded,
    )
267283
284+
268285def process_global_batch (
269286 data : BatchedDataDict [Any ],
270287 batch_idx : int ,
271288 batch_size : int ,
272289 loss_fn : LossFunction ,
273290 dp_group : torch .distributed .ProcessGroup ,
274- ) -> dict [str , Any ]:
291+ ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
292+ """Process a global batch for Megatron model forward pass."""
275293 batch = data .get_batch (batch_idx = batch_idx , batch_size = batch_size )
276294
277295 assert "sample_mask" in batch , "sample_mask must be present in the data!"
@@ -301,6 +319,7 @@ def process_global_batch(
301319 global_valid_toks ,
302320 )
303321
322+
304323def _pack_sequences_for_megatron (
305324 input_ids : torch .Tensor ,
306325 seq_lengths : torch .Tensor ,
@@ -605,13 +624,14 @@ def _unpack_sequences_from_megatron(
605624
606625 return unpacked_output
607626
627+
def check_sequence_dim(data: BatchedDataDict[Any]):
    """Sanity-check that dim 1 is the sequence dimension for every tensor.

    The expected sequence length is taken from ``data["input_ids"]``; every
    other tensor in ``data`` with rank > 1 must match it along dim 1.

    Returns:
        Tuple of ``(sequence_dim, seq_dim_size)`` where ``sequence_dim`` is
        always 1.
    """
    sequence_dim = 1
    seq_dim_size = data["input_ids"].shape[sequence_dim]
    for k, v in data.items():
        # Skip non-tensors and 1-D tensors — only rank > 1 carries a seq dim.
        if not torch.is_tensor(v) or v.dim() <= 1:
            continue
        assert v.shape[sequence_dim] == seq_dim_size, (
            f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape} for key {k}"
        )
    return sequence_dim, seq_dim_size
0 commit comments