 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 
+import enum
 from collections import deque
 from functools import lru_cache
 from math import ceil, log2
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
 from megatron.core.rerun_state_machine import RerunDataIterator
 
 
-def wrap_hybrid_cp_dataloader(
-    data_iterator, config, pg_collection: Optional[ProcessGroupCollection] = None
+class PackingScheduler(enum.Enum):
+    """Enum for supported sequence packing algorithms."""
+
+    HYBRID_CP = "hybrid_cp"
+    NAIVE_SEQUENCE_PACKING = "naive_sequence_packing"
+
+
+def wrap_dataloader(
+    data_iterator,
+    config,
+    scheduler_type: Union[PackingScheduler, str],
+    pg_collection: Optional[ProcessGroupCollection] = None,
 ):
     """
     A wrapper function that wraps around an existing data_iterator
@@ -26,6 +37,13 @@ def wrap_hybrid_cp_dataloader(
         dp_cp_group: Data parallel context parallel group.
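+
+    Example (illustrative; scheduler_type may also be passed as a PackingScheduler member):
+        data_iterator, num_micro_batches = wrap_dataloader(
+            data_iterator, config, scheduler_type="hybrid_cp"
+        )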
     """
 
+    scheduler_map: Dict[PackingScheduler, Type[BaseScheduler]] = {
+        PackingScheduler.HYBRID_CP: BalancedHybridCPscheduler,
+        PackingScheduler.NAIVE_SEQUENCE_PACKING: NaiveSequencePackingScheduler,
+    }
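+    # New packing algorithms are registered here: add a PackingScheduler member above
+    # and map it to its scheduler class.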
+
     def _get_global_seqlens(subsample_seqlens: torch.Tensor, dp_group) -> List[int]:
         """
         Gathers the sequence lengths of all subsamples from all DP ranks.
@@ -151,16 +169,17 @@ def _reroute_samples_to_hdp_ranks(
         ]
         # send_counts = [len(combined_sample_id_groups[d]) for d in range(total_hdp_gpus)]
 
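+        # send_num_split counts whole samples per destination rank (used for
+        # per-sample scalars such as "original_seq_len"); send_lens_split counts tokens.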
+        send_num_split = [0] * total_hdp_gpus
         send_lens_split = [0] * total_hdp_gpus
         for dest_rank in range(total_hdp_gpus):
             if dest_rank in dp_ranks:
-                send_lens_split[dest_rank] = sum(
-                    [
-                        global_id_seqlens[gid][1]
-                        for gid in combined_sample_id_groups[dest_rank]
-                        if gid in global_ids_this_rank
-                    ]
-                )
+                send_seq_lens = [
+                    global_id_seqlens[gid][1]
+                    for gid in combined_sample_id_groups[dest_rank]
+                    if gid in global_ids_this_rank
+                ]
+                send_num_split[dest_rank] = len(send_seq_lens)
+                send_lens_split[dest_rank] = sum(send_seq_lens)
             else:
                 # We only need to share local data with DP ranks that have different data.
                 send_lens_split[dest_rank] = 0
@@ -197,20 +216,30 @@ def _pack_sample_by_key(key: str) -> torch.Tensor:
         def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor):
             cursor = 0
             for i, gid in enumerate(recv_ids_sorted):
-                sample_len = global_id_seqlens[gid][1]
+                sample_len = 1 if key in ["original_seq_len"] else global_id_seqlens[gid][1]
                 recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len]
                 cursor += sample_len
 
         for key in data_keys:
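+            # "original_seq_len" stores one scalar per sample, so its all-to-all splits
+            # are sample counts; all other keys are token-level and split by sequence length.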
+            output_split_sizes, input_split_sizes = (
+                (recv_counts, send_num_split)
+                if key in ["original_seq_len"]
+                else (recv_lens_split, send_lens_split)
+            )
             send_tensor = _pack_sample_by_key(key)
+            recv_tensor_size = sum(output_split_sizes)
             recv_tensor = torch.empty(
-                sum(recv_lens_split), device=torch.cuda.current_device(), dtype=send_tensor.dtype
+                recv_tensor_size, device=torch.cuda.current_device(), dtype=send_tensor.dtype
             )
             torch.distributed.all_to_all_single(
                 output=recv_tensor,
                 input=send_tensor,
-                output_split_sizes=recv_lens_split,
-                input_split_sizes=send_lens_split,
+                output_split_sizes=output_split_sizes,
+                input_split_sizes=input_split_sizes,
                 group=dp_cp_group,
             )
             _unpack_sample_by_key(key, recv_tensor)
@@ -245,29 +274,24 @@ def _broadcast(item):
             group=parallel_state.get_tensor_model_parallel_group(),
         )
 
-    def _broadcast_num_samples_this_group(num_samples_this_group):
-        dev = torch.cuda.current_device()
-        # TODO(tailaim) do we need this barrier?
-        torch.distributed.barrier()
-
-        n = 0 if num_samples_this_group is None else int(num_samples_this_group.numel())
-        n = torch.tensor([n], dtype=torch.int64, device=dev)
-
-        _broadcast(n)
-        n = int(n.item())
+    # Convert string to enum if needed
+    if isinstance(scheduler_type, str):
+        try:
+            scheduler_type = PackingScheduler[scheduler_type.upper()]
+        except KeyError:
+            available_scheduler = ", ".join([scheduler.name for scheduler in PackingScheduler])
+            raise ValueError(
+                f"Unknown packing scheduler: {scheduler_type}. "
+                f"Available schedulers: {available_scheduler}"
+            )
 
-        assert n > 0, "there should be at least 1 sub samples in the group"
-        num_samples_this_group_broadcast = (
-            torch.empty(n, dtype=torch.int32, device=dev)
-            if num_samples_this_group is None
-            else num_samples_this_group
+    if scheduler_type not in scheduler_map:
+        available_scheduler = ", ".join([scheduler.name for scheduler in PackingScheduler])
+        raise ValueError(
+            f"Unknown scheduler: {scheduler_type}. Available schedulers: {available_scheduler}"
         )
-        _broadcast(num_samples_this_group_broadcast)
-        return num_samples_this_group_broadcast
 
-    cp_balancing_scheduler = BalancedHybridCPScheduler(
-        max_seq_len_per_rank=config.max_seqlen_per_dp_cp_rank
-    )
+    scheduler = scheduler_map[scheduler_type](config)
     if pg_collection is None:
         dp_cp_group = parallel_state.get_data_parallel_group(with_context_parallel=True)
         dp_group = parallel_state.get_data_parallel_group()
@@ -305,9 +329,7 @@ def _broadcast_num_samples_this_group(num_samples_this_group):
         subsample_seqlens.shape[0], offsets, seqlens_gathered, dp_group
     )
 
-    groups, sample_id_groups = cp_balancing_scheduler.get_groups_and_subsamples(
-        global_id_seqlens, config
-    )
+    groups, sample_id_groups = scheduler.get_groups_and_subsamples(global_id_seqlens, config)
 
     batch = _unpack_batch(batch)
     samples_this_rank_with_id = _reroute_samples_to_hdp_ranks(
@@ -353,10 +375,10 @@ def _pack_tensors(tensors):
 
         # TODO(tailaim): do we need attention_mask for sequence packing?
         new_sample = {}
-        new_sample["tokens"] = tokens
-        new_sample["labels"] = labels
-        new_sample["loss_mask"] = loss_mask
-        new_sample["position_ids"] = position_ids
+        new_sample["tokens"] = tokens.unsqueeze(0)
+        new_sample["labels"] = labels.unsqueeze(0)
+        new_sample["loss_mask"] = loss_mask.unsqueeze(0)
+        new_sample["position_ids"] = position_ids.unsqueeze(0)
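+        # unsqueeze(0) gives each packed tensor a leading batch dimension of 1,
+        # i.e. shape [1, packed_seq_len].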
         new_sample["local_cp_size"] = torch.tensor(
             partner_cp_size, dtype=torch.int32, device=dev
         )
@@ -367,9 +389,11 @@ def _pack_tensors(tensors):
         )
         cu_seqlens_padded = np.empty(len(samples) + 1, dtype=np.int32)
         cu_seqlens_padded[0] = 0
-        np.cumsum(lengths_padding, out=cu_seqlens_padded[1:])
-        cu_seqlens_padded = torch.from_numpy(cu_seqlens_padded).to(
-            device=dev, non_blocking=True
+        cu_seqlens_padded[1:] = np.cumsum(lengths_padding, out=cu_seqlens_padded[1:])
+        cu_seqlens_padded = (
+            torch.from_numpy(cu_seqlens_padded)
+            .to(device=dev, non_blocking=True, dtype=torch.int32)
+            .reshape(-1)
         )
         new_sample["cu_seqlens_padded"] = cu_seqlens_padded
 
@@ -379,27 +403,95 @@ def _pack_tensors(tensors):
         new_sample["max_seqlen"] = max_seqlen
 
         # create cu_seqlens without padding
-        lengths = torch.stack([s["original_seq_len"] for s in samples], dim=0)
+        lengths = torch.stack([s["original_seq_len"] for s in samples], dim=0).reshape(-1)
         cu_seqlens = torch.empty(lengths.numel() + 1, device=dev, dtype=torch.int32)
         cu_seqlens[0] = 0
-        cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
+        cu_seqlens[1:] = torch.cumsum(lengths, dim=0).reshape(-1)
         new_sample["cu_seqlens"] = cu_seqlens
 
         new_samples.append(new_sample)
 
     new_data_iterator = RerunDataIterator(iter(new_samples))
 
 
     return new_data_iterator, num_micro_batches
 
-class BalancedHybridCPScheduler:
+class BaseScheduler:
+    """
+    Base class for sequence packing schedulers.
+
+    Subclasses implement get_groups_and_subsamples(sample_id_seqlens, config).
+    """
+
+    def __init__(self, config):
+        pass
+
+
+class NaiveSequencePackingScheduler(BaseScheduler):
+    """
+    This scheduler simply packs sequences in their original order
+    until reaching the max sequence length.
+    It does not reorder sequences nor perform any load balancing.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.max_seq_len_all_ranks = config.max_seqlen_per_dp_cp_rank * config.context_parallel_size
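+        # The packing budget spans all CP ranks: per-rank cap times context-parallel size.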
+        self.dp_size = parallel_state.get_data_parallel_world_size()
+
+    def get_groups_and_subsamples(self, sample_id_seqlens, config):
+        """
+        Greedily pack (sample_id, seqlen) pairs in their original order until the
+        per-microbatch budget is reached; no reordering or load balancing is performed.
+        Returns (groups, sample_id_groups), both lists of microbatches of sample indices.
+        """
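+        # Illustrative example: with max_seq_len_all_ranks = 8 and
+        # sample_id_seqlens = [(0, 3), (1, 4), (2, 5)], the loop below yields
+        # groups = [[0, 1], [2]] before the dp_size adjustment.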
+        groups = []
+        sample_id_groups = []
+        sum_seqlen = 0
+        single_microbatch = []
+
+        for i in range(len(sample_id_seqlens)):
+            if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks:
+                single_microbatch.append(i)
+                sum_seqlen += sample_id_seqlens[i][1]
+            else:
+                groups.append(single_microbatch)
+                sample_id_groups.append(single_microbatch)
+                single_microbatch = [i]
+                sum_seqlen = sample_id_seqlens[i][1]
+        # flush the final (possibly partial) microbatch
+        if single_microbatch:
+            groups.append(single_microbatch)
+            sample_id_groups.append(single_microbatch)
+
+        # we want the number of microbatches to be a multiple of dp_size,
+        # so, if needed, we move a few samples out of earlier microbatches
+        # into new single-sample microbatches at the end
+        num_microbatches_before = len(sample_id_groups)
+        if num_microbatches_before % self.dp_size != 0:
+            remainder = num_microbatches_before % self.dp_size
+            num_to_move = self.dp_size - remainder
+            i = num_microbatches_before - 1
+            while num_to_move > 0:
+                assert i > 0, "Not enough samples to move"
+                if len(sample_id_groups[i]) > 1:
+                    seq_id = sample_id_groups[i].pop()
+                    sample_id_groups.append([seq_id])
+                    groups.append([seq_id])
+                    num_to_move -= 1
+                else:
+                    i -= 1
+        return groups, sample_id_groups
+
+
+class BalancedHybridCPscheduler(BaseScheduler):
     """
     This class provides the functionality to form groups of sub-samples
     such that all DPxCP ranks have a roughly balanced workload in the group.
     """
 
-    def __init__(self, max_seq_len_per_rank: int):
-        self.max_seq_len_per_rank = max_seq_len_per_rank
+    def __init__(self, config):
+        super().__init__(config)
+        self.max_seq_len_per_rank = config.max_seqlen_per_dp_cp_rank
         self.num_subsamples = 0
         self.num_subsamples_processed = 0
         self.free_resources = []
@@ -614,6 +706,8 @@ def next_hdp_group(
                 else:
                     chosen_members = group_members[best_gid]
             else:
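+                # best_gid can be None when no suitable group was found; break instead
+                # of indexing group_members with None.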
+                if best_gid is None:
+                    break
                 chosen_members = group_members[best_gid]
 
             # ---- Step 2 – if we decided to create a fresh group ----------------
@@ -731,7 +825,6 @@ def trim_overload():
             else:
                 break
 
-        # debugmtl make sure total_seq_len after packing smaller than max_seq_len
         # trim_overload()
 
         # Track samples in this group before redistribution to empty GPUs