-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Expand file tree
/
Copy pathdata_schedule_utils.py
More file actions
1030 lines (865 loc) · 39.8 KB
/
data_schedule_utils.py
File metadata and controls
1030 lines (865 loc) · 39.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved.
from collections import deque
from functools import lru_cache
from math import ceil, log2
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from megatron.core.extensions.transformer_engine import get_thd_partitioned_indices
from megatron.core.rerun_state_machine import RerunDataIterator
def get_cp_slice_for_thd(batch, cp_group):
    """Keep only this context-parallel rank's share of a THD-packed batch.

    The partition is computed with Transformer Engine's
    ``get_thd_partitioned_indices`` and applied along dim 0 to whichever of
    ``tokens``, ``position_ids``, ``labels`` and ``loss_mask`` the batch
    contains. The batch dict is modified in place; nothing is returned.
    With a CP group of size 1 (or less) the batch is left untouched.

    Args:
        batch: Dict with packed sequence data; must contain ``tokens`` and
            ``cu_seqlens_padded`` when CP size > 1.
        cp_group: Context parallel process group.
    """
    world = cp_group.size()
    if world <= 1:
        return
    rank = cp_group.rank()
    num_tokens = batch['tokens'].size(0)
    # Workaround for a Transformer Engine cu_seqlens issue: the padded offsets
    # are passed where cu_seqlens is expected, which yields correct results.
    # TODO: Revert this workaround once TE fixes the issue.
    offsets = batch["cu_seqlens_padded"]
    keep = get_thd_partitioned_indices(offsets, num_tokens, world, rank)
    for name in ('tokens', 'position_ids', 'labels', 'loss_mask'):
        if name not in batch:
            continue
        batch[name] = batch[name].index_select(0, keep)
def _unpack_batch(batch: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]:
"""
Unpacks the packed samples into a list of sub-samples.
Since each sub-sample may be routed to different DPxCP ranks,
we unpack the sample here to avoid unnecessarily transferring
the entire packed sample.
"""
batch_unpacked = []
dev = batch[0]["tokens"].device
original_seq_lens = []
padded_seq_lens = []
for sample in batch:
for key in sample.keys():
if len(sample[key].shape) == 2:
# squeeze the redundant batch dimension added by
# default collate_fn in pytorch dataloader
# we need a custom collate_fn for THD to avoid this
# current THD does not support micro_batch_size > 1 due to sft_dataset.py and
# data_loader in data_samples.py
sample[key] = sample[key].squeeze(0)
for sub_sample in range(sample["cu_seqlens"].shape[0] - 1):
sub_sample_dict = {}
start_idx = sample["cu_seqlens"][sub_sample]
end_idx = sample["cu_seqlens"][sub_sample + 1]
if end_idx - start_idx == 0:
continue
for key in ["tokens", "labels", "loss_mask", "position_ids"]:
sub_sample_dict[key] = sample[key][start_idx:end_idx]
# Since sft_dataset.py does not provide cu_seqlens_original,
# we assume original_seq_len equals padded_seq_len here.
# Ideally the dataset should define the pre-padding seq_len.
seq_len = (end_idx - start_idx).item()
original_seq_lens.append(seq_len)
padded_seq_lens.append(seq_len)
batch_unpacked.append(sub_sample_dict)
# Single H2D transfer for all seq lens
original_seq_lens_cuda = torch.tensor(original_seq_lens, device=dev)
padded_seq_lens_cuda = torch.tensor(padded_seq_lens, device=dev)
for i, sub_sample_dict in enumerate(batch_unpacked):
sub_sample_dict["original_seq_len"] = original_seq_lens_cuda[i : i + 1]
sub_sample_dict["padded_seq_len"] = padded_seq_lens_cuda[i : i + 1]
return batch_unpacked
def _get_global_seqlens_and_ids(subsample_seqlens: torch.Tensor, dp_group):
    """
    Gathers the sequence lengths of all subsamples from all DP ranks and calculates global IDs.

    Global IDs are assigned rank by rank: DP rank 0's sub-samples come first,
    then rank 1's, and so on. Requires CUDA (helper tensors are created with
    ``.cuda()``) and participation of every rank in ``dp_group`` (two
    ``all_gather`` calls).

    Args:
        subsample_seqlens: 1-D tensor of this rank's sub-sample lengths
            (the caller passes int32 on the current CUDA device).
        dp_group: Data parallel process group.

    Returns:
        Tuple ``(global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered)``:
        list of ``(global_id, seq_len)`` pairs, CUDA int32 tensor of this
        rank's global IDs, CPU int32 tensor of per-rank start offsets (length
        ``dp_size + 1``), and the flat Python list of all lengths.
    """
    world = dp_group.size()
    local_count = subsample_seqlens.shape[0]

    # Step 1: learn how many sub-samples every rank holds.
    count_tensor = torch.tensor([local_count], dtype=torch.int32).cuda()
    gathered_counts = [torch.zeros_like(count_tensor) for _ in range(world)]
    torch.distributed.all_gather(gathered_counts, count_tensor, group=dp_group)
    per_rank_counts = torch.stack(gathered_counts, dim=0).cpu().view(-1)
    longest = int(per_rank_counts.max().item())

    # Step 2: all_gather needs equally sized inputs, so right-pad with zeros.
    shortfall = longest - local_count
    if shortfall > 0:
        filler = torch.zeros(shortfall, dtype=torch.int32).cuda()
        padded = torch.cat([subsample_seqlens, filler], dim=0)
    else:
        padded = subsample_seqlens
    gathered = [torch.empty_like(padded) for _ in range(world)]
    torch.distributed.all_gather(gathered, padded, group=dp_group)

    # Step 3: strip each rank's padding and flatten into one host list.
    trimmed = [chunk[: per_rank_counts[r]] for r, chunk in enumerate(gathered)]
    seqlens_gathered = torch.cat(trimmed, dim=0).cpu().tolist()

    # Step 4: prefix sums give each rank's first global ID.
    running = torch.cumsum(per_rank_counts, dim=0, dtype=torch.int32)
    offsets = torch.cat([torch.zeros(1, dtype=torch.int32), running], dim=0)

    global_ids = torch.arange(len(seqlens_gathered), dtype=torch.int32).cuda()
    global_id_seqlens = list(enumerate(seqlens_gathered))
    my_rank = dp_group.rank()
    global_ids_this_rank = global_ids[offsets[my_rank] : offsets[my_rank + 1]]
    return global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered
def _pack_sequences(
samples: List,
padded_lengths: torch.Tensor,
original_lengths: torch.Tensor,
local_cp_size: Optional[torch.Tensor],
dev: torch.device,
) -> Dict[str, torch.Tensor]:
"""Pack multiple samples into a single packed sample."""
def _pack_tensors(tensors):
return torch.cat([t.reshape(-1) for t in tensors], dim=0)
tokens = _pack_tensors([sample["tokens"] for sample in samples])
labels = _pack_tensors([sample["labels"] for sample in samples])
loss_mask = _pack_tensors([sample["loss_mask"] for sample in samples])
position_ids = _pack_tensors([sample["position_ids"] for sample in samples])
new_sample = {}
new_sample["tokens"] = tokens
new_sample["labels"] = labels
new_sample["loss_mask"] = loss_mask
new_sample["position_ids"] = position_ids
padded_lengths = padded_lengths.to(device=dev, dtype=torch.int32, non_blocking=True).reshape(-1)
cu_seqlens_padded = torch.empty(padded_lengths.numel() + 1, device=dev, dtype=torch.int32)
cu_seqlens_padded[0] = 0
cu_seqlens_padded[1:] = torch.cumsum(padded_lengths, dim=0)
max_seqlen = torch.max(padded_lengths).to(dtype=torch.int32)
new_sample["cu_seqlens_padded"] = cu_seqlens_padded
new_sample["max_seqlen"] = max_seqlen
original_lengths = original_lengths.to(
device=dev, dtype=torch.int32, non_blocking=True
).reshape(-1)
cu_seqlens = torch.empty(original_lengths.numel() + 1, device=dev, dtype=torch.int32)
cu_seqlens[0] = 0
cu_seqlens[1:] = torch.cumsum(original_lengths, dim=0).reshape(-1)
new_sample["cu_seqlens"] = cu_seqlens
if local_cp_size is not None:
new_sample["local_cp_size"] = local_cp_size
return new_sample
def broadcast_tensor(item, src_rank, group) -> None:
    """Broadcast ``item`` in place from global rank ``src_rank`` within ``group``.

    A ``None`` item is silently skipped, so callers can pass optional tensors.
    """
    if item is None:
        return
    torch.distributed.broadcast(item, src_rank, group=group)
def broadcast_to_pp_group(
    new_samples,
    num_micro_batches,
    seqlen_sum_this_global_batch,
    seqlen_squared_sum_this_global_batch,
    pp_group,
    dev,
    is_dynamic_cp: bool = False,
):
    """
    Broadcast num_micro_batches, seqlen_sum_this_global_batch,
    seqlen_squared_sum_this_global_batch and metadata to middle PP stages.
    Before this broadcast, the new_samples on middle PP stages are None,
    after this broadcast, the new_samples on middle PP stages contain the metadata but
    without tokens, labels, loss_mask, position_ids.

    Nothing is communicated unless ``pp_group`` has more than two ranks (with
    two ranks there is no middle stage). Otherwise every rank of ``pp_group``
    must call this function, because all of them take part in two collective
    broadcasts from PP rank 0: first the payload length, then the payload.

    Payload layout (one float32 tensor):
        [num_micro_batches, seqlen_sum, seqlen_squared_sum,
         max_seqlen for each sample,
         local_cp_size for each sample            (only if is_dynamic_cp),
         cu_seqlens_0, cu_seqlens_padded_0, cu_seqlens_1, cu_seqlens_padded_1, ...]

    Args:
        new_samples: On PP rank 0, the packed sample dicts (read keys:
            "max_seqlen", "cu_seqlens", "cu_seqlens_padded", and
            "local_cp_size" when is_dynamic_cp). Not read on other ranks.
        num_micro_batches: Number of packed samples (sent from rank 0).
        seqlen_sum_this_global_batch: Scalar sent from rank 0.
        seqlen_squared_sum_this_global_batch: Scalar sent from rank 0.
        pp_group: Pipeline parallel process group.
        dev: Device the payload is moved to on rank 0.
        is_dynamic_cp: Whether per-sample "local_cp_size" is included.

    Returns:
        Tuple ``(new_samples, num_micro_batches, seqlen_sum_this_global_batch,
        seqlen_squared_sum_this_global_batch)``. PP rank 0 and the last stage
        get their own arguments back. Middle stages get values decoded from the
        payload: the two sums are numpy float32 scalars, and the samples hold
        only int32 metadata tensors.
    """
    pp_src_rank = torch.distributed.get_process_group_ranks(pp_group)[0]
    # With only a first and a last stage there is no middle stage to feed.
    if pp_group.size() > 2:
        if pp_group.rank() == 0:
            # NOTE(review): everything is sent as float32, which represents
            # integers exactly only up to 2**24; very large seqlen sums /
            # squared sums (or cu_seqlens) would be rounded — confirm acceptable.
            tensor_list = [
                torch.tensor(
                    [
                        num_micro_batches,
                        seqlen_sum_this_global_batch,
                        seqlen_squared_sum_this_global_batch,
                    ],
                    dtype=torch.float32,
                ).cuda()
            ]
            for sample in new_samples:
                tensor_list.append(sample["max_seqlen"].unsqueeze(0))
            if is_dynamic_cp:
                for sample in new_samples:
                    tensor_list.append(sample["local_cp_size"].unsqueeze(0))
            for sample in new_samples:
                tensor_list.append(sample["cu_seqlens"])
                tensor_list.append(sample["cu_seqlens_padded"])
            info_to_broadcast = torch.cat(tensor_list, dim=0).to(device=dev, dtype=torch.float32)
            # Send the length first so receivers can size their buffer.
            info_length_tensor = torch.tensor(info_to_broadcast.shape[0], dtype=torch.int32).cuda()
            broadcast_tensor(info_length_tensor, pp_src_rank, pp_group)
            broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group)
        else:
            info_length_tensor = torch.tensor(0, dtype=torch.int32).cuda()
            broadcast_tensor(info_length_tensor, pp_src_rank, pp_group)
            # Receive buffer sized from the broadcast length.
            info_to_broadcast = torch.empty(info_length_tensor.item(), dtype=torch.float32).cuda()
            broadcast_tensor(info_to_broadcast, pp_src_rank, pp_group)
            # The last stage only participates in the collectives; it keeps
            # the arguments it was called with.
            if pp_group.rank() != pp_group.size() - 1:
                # middle PP stages receive the broadcasted info and unpack it
                info_numpy = info_to_broadcast.cpu().numpy()
                num_micro_batches = int(info_numpy[0])
                seqlen_sum_this_global_batch = info_numpy[1]
                seqlen_squared_sum_this_global_batch = info_numpy[2]
                max_seqlens = info_to_broadcast[3 : 3 + num_micro_batches]
                local_cp_sizes = (
                    info_to_broadcast[3 + num_micro_batches : 3 + 2 * num_micro_batches]
                    if is_dynamic_cp
                    else None
                )
                cu_seqlens_list = []
                cu_seqlens_padded_list = []
                # cu_seqlens always starts with 0, and the other metadata values
                # (num_micro_batches, seqlen_sum, seqlen_squared_sum, max_seqlens)
                # are always positive, so we can use 0 as the delimiter to locate
                # the start of each cu_seqlens / cu_seqlens_padded tensor.
                # This avoids an extra broadcast for the lengths of cu_seqlens.
                # NOTE(review): this requires exactly one zero per cu_seqlens
                # tensor (its leading entry); a zero-length first sub-sequence
                # would add a second zero and misalign the parsing.
                indices = np.where(info_numpy == 0)[0]
                for i in range(num_micro_batches):
                    cu_seqlens_list.append(info_to_broadcast[indices[i * 2] : indices[i * 2 + 1]])
                    if i == num_micro_batches - 1:
                        # The last cu_seqlens_padded runs to the end of the payload.
                        cu_seqlens_padded_list.append(info_to_broadcast[indices[i * 2 + 1] :])
                    else:
                        cu_seqlens_padded_list.append(
                            info_to_broadcast[indices[i * 2 + 1] : indices[i * 2 + 2]]
                        )
                new_samples = []
                for i in range(num_micro_batches):
                    new_sample = {}
                    new_sample["max_seqlen"] = max_seqlens[i].to(torch.int32)
                    new_sample["cu_seqlens"] = cu_seqlens_list[i].to(torch.int32)
                    new_sample["cu_seqlens_padded"] = cu_seqlens_padded_list[i].to(torch.int32)
                    if is_dynamic_cp:
                        new_sample["local_cp_size"] = local_cp_sizes[i].to(torch.int32)
                    new_samples.append(new_sample)
    return (
        new_samples,
        num_micro_batches,
        seqlen_sum_this_global_batch,
        seqlen_squared_sum_this_global_batch,
    )
def broadcast_scalars(values: List, group, dev, dtype=torch.float32) -> List:
    """
    Broadcast scalar values from group rank 0 to all ranks in the group.

    Args:
        values: List of scalar values. On group rank 0 these are the payload.
            On every other rank only ``len(values)`` is used (it sizes the
            receive buffer), so all ranks must pass a list of the same length;
            the contents on non-source ranks are ignored.
        group: The process group to broadcast within. The source is the
            group's rank 0 (its first global rank).
        dev: The device to use for the temporary tensor.
        dtype: The data type for the temporary tensor; values are cast to it
            for transport.

    Returns:
        If the group has at most one rank, ``values`` unchanged. Otherwise
        group rank 0 gets its own ``values`` object back, while every other
        rank gets a new list built from the received tensor (Python floats
        for floating ``dtype``, e.g. the default float32).
    """
    if group.size() <= 1:
        return values
    src_rank = torch.distributed.get_process_group_ranks(group)[0]
    num_values = len(values)
    if group.rank() == 0:
        info_to_broadcast = torch.tensor(values, dtype=dtype, device=dev)
    else:
        # Receivers must know the payload size up front.
        info_to_broadcast = torch.zeros(num_values, dtype=dtype, device=dev)
    broadcast_tensor(info_to_broadcast, src_rank, group)
    if group.rank() != 0:
        values = info_to_broadcast.cpu().tolist()
    return values
def create_data_iterator(
    new_samples, tp_group, config, vpp_has_data=None, is_dynamic_cp: bool = False
):
    """Wrap the scheduled samples in data iterator(s), handling virtual pipeline parallelism.

    Without VPP (``virtual_pipeline_model_parallel_size`` is None or <= 1),
    tensor-parallel rank 0 gets a single iterator over ``new_samples`` and
    every other TP rank gets None.

    With VPP, the result is a list with one entry per VPP stage. On TP rank 0,
    a stage flagged in ``vpp_has_data`` iterates the full samples, and any other
    stage iterates metadata-only dicts ("max_seqlen", "cu_seqlens",
    "cu_seqlens_padded", plus "local_cp_size" when ``is_dynamic_cp``; absent
    keys are skipped). Other TP ranks get a list of Nones.

    Args:
        new_samples: The packed samples after scheduling.
        tp_group: Tensor parallel process group.
        config: Model parallel config.
        vpp_has_data: A list of booleans (one per VPP stage) indicating which
            VPP stages originally had a data_iterator. None if VPP is disabled.
        is_dynamic_cp: Whether to keep "local_cp_size" in the metadata dicts.
    """
    vpp_size = config.virtual_pipeline_model_parallel_size
    if vpp_size is None or vpp_size <= 1:
        return RerunDataIterator(iter(new_samples)) if tp_group.rank() == 0 else None

    if tp_group.rank() != 0:
        return [None] * vpp_size

    kept_keys = ["max_seqlen", "cu_seqlens", "cu_seqlens_padded"]
    if is_dynamic_cp:
        kept_keys.append("local_cp_size")
    metadata_only = [
        {key: sample[key] for key in kept_keys if key in sample} for sample in new_samples
    ]

    iterators = []
    for stage in range(vpp_size):
        has_data = vpp_has_data is not None and vpp_has_data[stage]
        source = new_samples if has_data else metadata_only
        iterators.append(RerunDataIterator(iter(source)))
    return iterators
def reroute_samples_to_dcp_ranks(
    batch,
    global_ids_this_rank,
    global_id_seqlens,
    sample_id_groups,
    offsets,
    dp_group,
    tp_group,
    dp_cp_group,
    total_dcp_gpus,
):
    """
    Reroutes the sub-samples to the correct rank after scheduling.

    For each key in the batch dict, we perform an all-to-all communication
    (``all_to_all_single`` over ``dp_cp_group``) to transfer the data to the
    correct ranks. "original_seq_len" / "padded_seq_len" are exchanged as one
    element per sub-sample; every other key as ``seq_len`` elements per
    sub-sample (looked up in ``global_id_seqlens``).

    Args:
        batch: This rank's sub-sample dicts, indexed by local position; all
            dicts are assumed to share the keys of ``batch[0]``.
        global_ids_this_rank: Global IDs of ``batch`` entries (tensor or sequence).
        global_id_seqlens: ``(global_id, seq_len)`` pairs indexed by global ID.
        sample_id_groups: Per-microbatch, per-DPxCP-rank lists of global IDs.
        offsets: Per-DP-rank start offsets of global IDs (length dp_size + 1).
        dp_group, tp_group, dp_cp_group: Process groups.
        total_dcp_gpus: Number of ranks in the DPxCP group.

    Returns:
        Dict mapping global sample ID to the received sample dict, for every
        sub-sample scheduled onto this rank.
    """
    import bisect

    # Host-side copy of the local IDs. Testing ``gid in <cuda tensor>`` would
    # force one device synchronization per lookup.
    if torch.is_tensor(global_ids_this_rank):
        local_ids = global_ids_this_rank.tolist()
    else:
        local_ids = list(global_ids_this_rank)
    gid2local_id = {int(gid): i for i, gid in enumerate(local_ids)}

    dcp_rank = dp_cp_group.rank()
    tp_size = tp_group.size()
    dcp_size = dp_cp_group.size()
    # Position of every DP-group member inside the DPxCP group, computed once.
    dp_ranks = [
        (r // tp_size) % dcp_size for r in torch.distributed.get_process_group_ranks(dp_group)
    ]
    dp_rank_set = set(dp_ranks)
    # Last global ID owned by each DP rank. bisect_left on this list finds the
    # owning DP rank, equivalent to torch.bucketize(gid, offsets[1:] - 1).
    last_gid_per_dp_rank = (offsets[1:] - 1).tolist()

    def _gid_to_src_rank(gid: int) -> int:
        return dp_ranks[bisect.bisect_left(last_gid_per_dp_rank, gid)]

    data_keys = batch[0].keys()
    seq_len_keys = ("original_seq_len", "padded_seq_len")
    device = torch.cuda.current_device()

    # Create the send plan: all IDs each destination rank will own, sorted.
    combined_sample_id_groups: List[List[int]] = [[] for _ in range(total_dcp_gpus)]
    for d in range(total_dcp_gpus):
        for sample_id_group in sample_id_groups:
            combined_sample_id_groups[d].extend(sample_id_group[d])
        combined_sample_id_groups[d].sort()

    # all_to_all_single splits the send buffer in DPxCP rank order, so the
    # buffer contents must follow exactly that order. Ranks outside the DP
    # group receive nothing from this rank.
    send_ids_per_dest = [
        [gid for gid in combined_sample_id_groups[d] if gid in gid2local_id]
        if d in dp_rank_set
        else []
        for d in range(total_dcp_gpus)
    ]
    send_ids_sorted = [gid for ids in send_ids_per_dest for gid in ids]
    send_num_split = [len(ids) for ids in send_ids_per_dest]
    send_lens_split = [
        sum(global_id_seqlens[gid][1] for gid in ids) for ids in send_ids_per_dest
    ]

    # Create the recv plan: group the IDs this rank will own by source rank.
    recv_sample_id_groups = [[] for _ in range(total_dcp_gpus)]
    for gid in combined_sample_id_groups[dcp_rank]:
        recv_sample_id_groups[_gid_to_src_rank(gid)].append(gid)
    recv_lens_split = [
        sum(global_id_seqlens[gid][1] for gid in ids) for ids in recv_sample_id_groups
    ]
    recv_ids_sorted = [gid for ids in recv_sample_id_groups for gid in ids]
    recv_counts = [len(ids) for ids in recv_sample_id_groups]
    recv_samples = [{k: None for k in data_keys} for _ in range(len(recv_ids_sorted))]

    def _pack_sample_by_key(key: str) -> torch.Tensor:
        flattened_tensors = [
            batch[gid2local_id[gid]][key].to(device, non_blocking=True).reshape(-1)
            for gid in send_ids_sorted
        ]
        if flattened_tensors:
            return torch.cat(flattened_tensors, dim=0)
        # Nothing to send: the split sizes sum to 0, so the input must be
        # empty as well.
        return torch.empty(0, device=device, dtype=batch[0][key].dtype)

    def _unpack_sample_by_key(key: str, recv_tensor: torch.Tensor):
        cursor = 0
        for i, gid in enumerate(recv_ids_sorted):
            sample_len = 1 if key in seq_len_keys else global_id_seqlens[gid][1]
            recv_samples[i][key] = recv_tensor[cursor : cursor + sample_len]
            cursor += sample_len

    for key in data_keys:
        if key in seq_len_keys:
            output_split_sizes, input_split_sizes = recv_counts, send_num_split
        else:
            output_split_sizes, input_split_sizes = recv_lens_split, send_lens_split
        send_tensor = _pack_sample_by_key(key)
        recv_tensor = torch.empty(
            sum(output_split_sizes), device=device, dtype=send_tensor.dtype
        )
        torch.distributed.all_to_all_single(
            output=recv_tensor,
            input=send_tensor,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=dp_cp_group,
        )
        _unpack_sample_by_key(key, recv_tensor)

    return {recv_id: recv_samples[i] for i, recv_id in enumerate(recv_ids_sorted)}
def build_packed_microbatches(
    samples_this_rank_with_id: Dict[int, Dict[str, torch.Tensor]],
    sample_id_groups: List[List[List[int]]],
    dcp_rank: int,
    dev: torch.device,
    is_dynamic_cp: bool = False,
) -> List[Dict[str, torch.Tensor]]:
    """Build packed samples for each microbatch.

    For every microbatch, this rank's scheduled sub-samples are packed (in
    scheduled order) into one sample. All length tensors are concatenated once
    and then sliced per microbatch. With dynamic CP, each microbatch's
    ``local_cp_size`` is the number of ranks whose ID list contains this
    rank's first sample ID in that microbatch.

    Args:
        samples_this_rank_with_id: Mapping from global sample ID to sample dict,
            as returned by reroute_samples_to_dcp_ranks.
        sample_id_groups: Per-microbatch, per-rank lists of global sample IDs.
        dcp_rank: This rank's index within the DP×CP group.
        dev: Target device.
        is_dynamic_cp: Whether dynamic context parallel is enabled.

    Returns:
        One packed sample dict per microbatch.
    """
    per_mb_samples = [
        [samples_this_rank_with_id[sid] for sid in rank_groups[dcp_rank]]
        for rank_groups in sample_id_groups
    ]

    cp_sizes_on_dev = None
    if is_dynamic_cp:
        cp_sizes: List[int] = []
        for rank_groups in sample_id_groups:
            mine = rank_groups[dcp_rank]
            # Ranks sharing this rank's first sample form its CP group.
            cp_sizes.append(sum(1 for ids in rank_groups if mine[0] in ids))
        cp_sizes_on_dev = torch.tensor(cp_sizes, dtype=torch.int32, device=dev)

    bounds: List[int] = [0]
    original_parts = []
    padded_parts = []
    for group in per_mb_samples:
        bounds.append(bounds[-1] + len(group))
        original_parts.extend(s["original_seq_len"].reshape(-1) for s in group)
        padded_parts.extend(s["padded_seq_len"].reshape(-1) for s in group)

    padded_all = torch.cat(padded_parts, dim=0).to(dtype=torch.int32)
    original_all = torch.cat(original_parts, dim=0).to(dtype=torch.int32)

    packed: List[Dict[str, torch.Tensor]] = []
    for mb_idx, group in enumerate(per_mb_samples):
        lo, hi = bounds[mb_idx], bounds[mb_idx + 1]
        cp_size = cp_sizes_on_dev[mb_idx] if is_dynamic_cp else None
        packed.append(
            _pack_sequences(group, padded_all[lo:hi], original_all[lo:hi], cp_size, dev)
        )
    return packed
def get_batch_and_global_seqlens(data_iterator, num_microbatches, dp_group):
    """
    Get the batch and global sequence lengths.

    Each DP rank loads the same number of sequences, so we need to gather the sequence
    lengths from all ranks then we can schedule the sequences into groups.

    Args:
        data_iterator: The data iterator. Each ``next()`` must yield either a
            sample dict or a list of sample dicts.
        num_microbatches: The number of items to draw from the iterator.
        dp_group: The data parallel group.

    Returns:
        Tuple ``(batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered)``:
            batch: This rank's unpacked sub-sample dicts.
            global_id_seqlens: ``(global_id, seq_len)`` pairs for all DP ranks.
            global_ids_this_rank: The global IDs locally present on this rank.
            offsets: Per-DP-rank start offsets into the global ID space.
            seqlens_gathered: Flat list of all sub-sample lengths across DP ranks.

    Raises:
        ValueError: If the iterator yields something other than a dict or list.
    """
    batch_list = [next(data_iterator) for _ in range(num_microbatches)]
    batch = []
    for item in batch_list:
        if isinstance(item, dict):
            batch.append(item)
        elif isinstance(item, list):
            batch.extend(item)
        else:
            raise ValueError(f"Invalid item type: {type(item)}")
    # in sft_dataset.py, sequences are already packed before rescheduling,
    # so we need to unpack them here and repack after rescheduling.
    # This is only to adapt to the current megatron-lm sft_dataset.
    # If you implement your own dataset, just have __getitem__ return List[Dict]
    # and this step can be skipped.
    batch = _unpack_batch(batch)
    subsample_seqlens = torch.cat([sample["padded_seq_len"] for sample in batch]).to(
        dtype=torch.int32, device=torch.cuda.current_device()
    )
    global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered = (
        _get_global_seqlens_and_ids(subsample_seqlens, dp_group)
    )
    return batch, global_id_seqlens, global_ids_this_rank, offsets, seqlens_gathered
# =============================================================================
# Dynamic CP scheduling algorithms (used by DefaultDynamicCPScheduler)
# =============================================================================
def next_hdp_group(
    sample_seqlens: List[Tuple[int, int]],
    compute_estimator: Callable[[int], float],
    total_gpus: int,
    gpus_needed_fn: Callable[[int], int],
    make_buckets_equal_fn: Callable,
    max_seq_len_per_rank: float,
    get_total_workload_fn: Callable,
    delta: float = 0.05,
    strategy: str = "dp",
    eps_bucket: float = 0.10,
) -> Tuple[List[List[int]], List[Tuple[int, int]], List[float], List[List[int]]]:
    """Form one balanced micro-batch group across DPxCP ranks.

    This is a standalone version of the scheduling algorithm extracted from
    DefaultDynamicCPScheduler so it can live in a utils module.

    Samples are taken from the buckets built by ``make_buckets_equal_fn`` and
    greedily assigned to groups of ranks. A sample of length ``L`` spans
    ``gpus_needed_fn(L)`` ranks. It either joins an existing group of exactly
    that size (if the group's packed per-rank length stays within
    ``max_seq_len_per_rank``) or opens a new group on the least-loaded free
    ranks, whichever currently has the lower load. Ranks left empty at the end
    are filled by widening the smallest groups (doubling, capped at
    ``total_gpus``).

    Args:
        sample_seqlens: ``(sample_id, seq_len)`` pairs to schedule.
        compute_estimator: ``seq_len -> cost`` added to every rank of the
            group that receives the sample.
        total_gpus: Number of DPxCP ranks.
        gpus_needed_fn: callable(seq_len) -> int, ranks a sample must span.
        make_buckets_equal_fn: callable(sample_seqlens, compute_estimator) ->
            list[deque] of ``(sample_id, seq_len)``, consumed from the left.
        max_seq_len_per_rank: max tokens per rank for packing
            (bound on ``sum(seq_len / group_size)`` within one group).
        get_total_workload_fn: callable(seq_len, cp_size) -> float, used to
            recompute costs when a group is widened to fill empty ranks.
        delta: Relative balance tolerance; scheduling may stop early once
            ``max(exec) - min(exec) <= delta * max(exec)``.
        strategy: ``"dp"`` always scans buckets from the first one; any other
            value scans round-robin from a cursor that only ``"pp"`` advances.
        eps_bucket: Unused here; kept for signature compatibility.

    Returns:
        ``(micro_batches, leftovers, exec_times, sample_ids_per_gpu)``:
        ``micro_batches[r]`` are the seq lengths placed on rank ``r``,
        ``leftovers`` the ``(sample_id, seq_len)`` pairs not scheduled,
        ``exec_times[r]`` the estimated cost of rank ``r``, and
        ``sample_ids_per_gpu[r]`` the sample IDs matching ``micro_batches[r]``.

    Raises:
        AssertionError: if ranks remain empty and no group exists to widen
            (e.g. the first sample needs more ranks than available), or if
            widening groups lost samples.
    """
    if not sample_seqlens:
        return (
            [[] for _ in range(total_gpus)],
            [],
            [0.0 for _ in range(total_gpus)],
            [[] for _ in range(total_gpus)],
        )
    buckets = make_buckets_equal_fn(sample_seqlens, compute_estimator)
    micro_batches = [[] for _ in range(total_gpus)]
    exec_times = [0.0 for _ in range(total_gpus)]
    sample_ids_per_gpu = [[] for _ in range(total_gpus)]
    packing_sequence_len = {}  # group id -> packed tokens per rank
    gpu_group_id = [None] * total_gpus  # rank -> group id (None = free)
    group_members = {}  # group id -> list of ranks
    group_size = {}  # group id -> number of ranks
    next_gid = 0
    pp_cursor = 0
    prev_needed = None
    check_balance = False
    while buckets:
        sample_seq_tuple = bucket_idx = None
        needed = None
        scan_order = (
            range(len(buckets))
            if strategy == "dp"
            else [(pp_cursor + i) % len(buckets) for i in range(len(buckets))]
        )
        # Pick the head of the first bucket whose sample can be placed at all.
        for idx in scan_order:
            if not buckets[idx]:
                continue
            cand_tuple = buckets[idx][0]
            cand_seq_len = cand_tuple[1]
            needed = gpus_needed_fn(cand_seq_len)
            candidate_gids = [gid for gid, sz in group_size.items() if sz == needed]
            free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None]
            if candidate_gids or len(free_ranks) >= needed:
                sample_seq_tuple, bucket_idx = cand_tuple, idx
                break
        if sample_seq_tuple is None:
            break
        if strategy == "pp":
            pp_cursor = (bucket_idx + 1) % len(buckets)
        sample_id, seq_len = sample_seq_tuple
        needed = gpus_needed_fn(seq_len)
        if prev_needed is None:
            prev_needed = needed
        # Existing groups of the right size that still have packing room.
        candidate_gids = [
            gid
            for gid, sz in group_size.items()
            if sz == needed and packing_sequence_len[gid] + seq_len / needed <= max_seq_len_per_rank
        ]
        if candidate_gids:
            best_gid, best_load = min(
                ((gid, max(exec_times[r] for r in group_members[gid])) for gid in candidate_gids),
                key=lambda t: t[1],
            )
        else:
            best_gid, best_load = None, float("inf")
        free_ranks = [r for r, gid in enumerate(gpu_group_id) if gid is None]
        if len(free_ranks) >= needed:
            free_sorted = sorted(free_ranks, key=lambda r: exec_times[r])
            new_members = free_sorted[:needed]
            new_load = exec_times[new_members[-1]]
            if new_load < best_load:
                best_gid = None
                chosen_members = new_members
            else:
                chosen_members = group_members[best_gid]
        else:
            if best_gid is None:
                break
            chosen_members = group_members[best_gid]
        if best_gid is None:
            best_gid = next_gid
            next_gid += 1
            group_members[best_gid] = chosen_members
            group_size[best_gid] = needed
            for r in chosen_members:
                gpu_group_id[r] = best_gid
        per_gpu_cost = compute_estimator(seq_len)
        packing_sequence_len[best_gid] = packing_sequence_len.get(best_gid, 0) + seq_len / needed
        for r in chosen_members:
            micro_batches[r].append(seq_len)
            exec_times[r] += per_gpu_cost
            sample_ids_per_gpu[r].append(sample_id)
        buckets[bucket_idx].popleft()
        # Drop exhausted leading buckets; keep the round-robin cursor in range.
        while buckets and not buckets[0]:
            buckets.pop(0)
            pp_cursor %= max(1, len(buckets))
        # Once a sample needs fewer ranks than the first one did, stop as soon
        # as the load is balanced within delta.
        if needed < prev_needed:
            check_balance = True
        if (
            check_balance
            and buckets
            and max(exec_times) - min(exec_times) <= delta * max(exec_times)
        ):
            break
    leftovers = []
    for b in buckets:
        for sample_seq_tuple in b:
            leftovers.append(sample_seq_tuple)

    def trim_overload():
        # Move trailing samples off the most loaded group while that reduces
        # the spread between the most and least loaded ranks.
        while True:
            cur_max = max(exec_times)
            cur_min = min(exec_times)
            cur_slack = cur_max - cur_min
            if cur_slack <= delta * cur_max:
                break
            if cur_min == 0:
                break
            max_r = exec_times.index(cur_max)
            gid = gpu_group_id[max_r]
            members = group_members[gid]
            if not micro_batches[max_r] or len(micro_batches[max_r]) <= 1:
                break
            seq = micro_batches[max_r][-1]
            per_gpu_cost = compute_estimator(seq)
            proj_times = exec_times[:]
            for r in members:
                proj_times[r] -= per_gpu_cost
            proj_slack = max(proj_times) - min(proj_times)
            if proj_slack < cur_slack:
                sample_id_to_remove = sample_ids_per_gpu[max_r][-1]
                for r in members:
                    micro_batches[r].pop()
                    exec_times[r] -= per_gpu_cost
                    sample_ids_per_gpu[r].pop()
                leftovers.append((sample_id_to_remove, seq))
            else:
                break

    # TODO(tailaim): uncomment this to support different ranks have different num_microbatches
    # trim_overload()
    total_work_before = sum(len(mb) for mb in micro_batches)

    def fill_empty_gpus(micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size):
        # Widen the first smallest group into empty ranks, shifting the ranks
        # after it to the right. NOTE(review): this assumes group members are
        # contiguous rank ranges, group ids increase with rank position, and
        # empty ranks trail the used ones; the sample-count assertion in the
        # caller guards against lost work.
        empty_gpus = [i for i in range(total_gpus) if not micro_batches[i]]
        if not empty_gpus:
            return (micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size)
        existing_group_sizes = set(group_size.values())
        assert existing_group_sizes, (
            "There should be at least one group existing, cannot redistribute, "
            "try to increase 'max-seqlen-per-dp-cp-rank'."
        )
        min_group_size = min(existing_group_sizes)
        next_power = min(min_group_size * 2, total_gpus)
        for gid, size in group_size.items():
            if size == min_group_size:
                members = group_members[gid]
                needed_count = next_power - min_group_size
                group_start_gpu = members[0]
                group_end_gpu = members[-1]
                empty_gpu = [idx for idx, work in enumerate(micro_batches) if not work][0]
                assert not all(
                    work for work in micro_batches[empty_gpu : empty_gpu + needed_count]
                ), "Empty GPUs were detected but not enough to expand."
                work_to_push = micro_batches[group_end_gpu + 1 : empty_gpu]
                exec_times_to_push = exec_times[group_end_gpu + 1 : empty_gpu]
                sample_ids_to_push = sample_ids_per_gpu[group_end_gpu + 1 : empty_gpu]
                # Independent lists per slot (``[[]] * n`` would alias them).
                new_micro_batches = [[] for _ in range(len(micro_batches))]
                new_exec_times = [0.0] * len(exec_times)
                new_sample_ids_per_gpu = [[] for _ in range(len(sample_ids_per_gpu))]
                for i in range(group_start_gpu):
                    new_micro_batches[i] = micro_batches[i]
                    new_exec_times[i] = exec_times[i]
                    new_sample_ids_per_gpu[i] = sample_ids_per_gpu[i]
                for i in range(group_start_gpu, group_end_gpu + needed_count + 1):
                    # Copy per rank so members of the widened group do not
                    # share one list object.
                    new_micro_batches[i] = list(micro_batches[group_end_gpu])
                    # NOTE(review): only the first packed sequence's workload
                    # is counted here; if the group packed several sequences
                    # this underestimates its cost — confirm intended.
                    new_exec_times[i] = get_total_workload_fn(
                        micro_batches[group_end_gpu][0], next_power
                    )
                    new_sample_ids_per_gpu[i] = list(sample_ids_per_gpu[group_end_gpu])
                for i, work in enumerate(work_to_push):
                    new_micro_batches[group_end_gpu + needed_count + 1 + i] = work
                    new_exec_times[group_end_gpu + needed_count + 1 + i] = exec_times_to_push[i]
                    new_sample_ids_per_gpu[group_end_gpu + needed_count + 1 + i] = (
                        sample_ids_to_push[i]
                    )
                group_size[gid] = next_power
                group_members[gid] = list(range(members[0], members[-1] + needed_count + 1))
                for pushed_gid in group_size.keys():
                    if pushed_gid > gid:
                        group_members[pushed_gid] = [
                            x + needed_count for x in group_members[pushed_gid]
                        ]
                return (
                    new_micro_batches,
                    new_exec_times,
                    new_sample_ids_per_gpu,
                    group_members,
                    group_size,
                )

    empty_gpus = any([not micro_batches[i] for i in range(total_gpus)])
    while empty_gpus:
        micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size = fill_empty_gpus(
            micro_batches, exec_times, sample_ids_per_gpu, group_members, group_size
        )
        empty_gpus = any([not micro_batches[i] for i in range(total_gpus)])
    total_work_after = sum(len(mb) for mb in micro_batches)
    assert (
        total_work_after >= total_work_before
    ), f"Samples were removed: {total_work_before} -> {total_work_after}"
    return micro_batches, leftovers, exec_times, sample_ids_per_gpu
def align_sample_id_groups(sample_id_groups: List, microbatch_group_size_per_vp_stage: int) -> List:
    """Align len(sample_id_groups) to microbatch_group_size_per_vp_stage when VPP is enabled.

    Standalone version extracted from DefaultDynamicCPScheduler.

    Each element of ``sample_id_groups`` is one microbatch: a list with one entry
    per HDP rank, each entry being the list of sample ids held by that rank.
    A run of consecutive ranks whose lists contain the same leading sample id is
    treated as one CP group.

    While the number of microbatches is not a multiple of
    ``microbatch_group_size_per_vp_stage``, microbatches are visited from the
    tail towards the head. Each visited microbatch is split into two at the
    first CP-group boundary at or after its middle rank. Each half is padded
    with empty ranks and then re-expanded to the full rank count by repeatedly
    doubling the span of one of its CP groups. A microbatch that consists of a
    single CP group cannot be split and is skipped.

    Args:
        sample_id_groups: Microbatches to align. Modified in place: split
            results overwrite the visited entry and the other half is appended.
        microbatch_group_size_per_vp_stage: Required multiple for the number of
            microbatches.

    Returns:
        The same ``sample_id_groups`` list object, whose length is now a
        multiple of ``microbatch_group_size_per_vp_stage``.

    Raises:
        AssertionError: If a full pass over the microbatches finds nothing left
            to split, if a padded half cannot be re-expanded, or if a
            microbatch's CP group sizes are not non-increasing.
    """
    multiple = int(microbatch_group_size_per_vp_stage)
    remainder = (-len(sample_id_groups)) % multiple
    i = len(sample_id_groups) - 1

    def split_group(sample_id_group):
        """Split one microbatch at the first CP boundary at or after its middle rank.

        Returns ``(new_mb, old_mb)``: ``old_mb`` holds the leading CP groups and
        ``new_mb`` the trailing ones, each re-expanded to the original rank
        count. Returns ``(None, None)`` when there is only one CP group.
        """
        total_hdp_ranks = len(sample_id_group)
        # cu_ranks[j] is the first rank of the j-th CP group (cumulative offsets).
        cu_ranks = [0]
        prev_cp_size = 0
        while cu_ranks[-1] != total_hdp_ranks:
            start_rank = cu_ranks[-1]
            sid0 = sample_id_group[start_rank][0]
            cp_size = 0
            for r in range(start_rank, total_hdp_ranks):
                if sid0 in sample_id_group[r]:
                    cp_size += 1
                else:
                    break
            # CP groups are expected in non-increasing size order.
            assert (
                prev_cp_size == 0 or cp_size <= prev_cp_size
            ), f"split_group: CP size is not decreasing: prev={prev_cp_size}, cur={cp_size}"
            cu_ranks.append(start_rank + cp_size)
            prev_cp_size = cp_size
        if len(cu_ranks) == 2:
            # A single CP group spans all ranks: nothing to split.
            return None, None
        k = 0
        while cu_ranks[k] < total_hdp_ranks // 2:
            k += 1
        old_mb = sample_id_group[: cu_ranks[k]] + [[] for _ in range(total_hdp_ranks - cu_ranks[k])]
        new_mb = sample_id_group[cu_ranks[k] :] + [[] for _ in range(cu_ranks[k])]
        old_mb = fill_empty_by_expanding_cp(old_mb)
        new_mb = fill_empty_by_expanding_cp(new_mb)
        return new_mb, old_mb

    def fill_empty_by_expanding_cp(sample_id_group):
        """Fill the trailing empty ranks by repeatedly doubling the span of a CP group.

        Note: the duplicated rank entries are references to the same inner lists
        (slice assignment copies references, not contents).
        """

        def fill_empty(sample_id_group):
            """Perform one doubling step in place (or nothing if no branch applies)."""
            empty_size = sum(1 for x in sample_id_group if len(x) == 0)
            i = len(sample_id_group) - 1 - empty_size
            prev_cp_size = 0
            while i >= 0:
                sid0 = sample_id_group[i][0]
                cp_size = 0
                # Bound check first, so i == -1 never wraps around to the last rank.
                while i >= 0 and sid0 in sample_id_group[i]:
                    cp_size += 1
                    i -= 1
                if cp_size > prev_cp_size and prev_cp_size != 0:
                    # The group just to the right (size prev_cp_size) is smaller than
                    # this one: double it and shift the later groups into the empties.
                    start_idx = i + 1 + cp_size
                    end_idx = -empty_size + prev_cp_size if -empty_size + prev_cp_size < 0 else None
                    sample_id_group[start_idx + 2 * prev_cp_size : end_idx] = sample_id_group[
                        start_idx + prev_cp_size : -empty_size
                    ]
                    sample_id_group[start_idx + prev_cp_size : start_idx + 2 * prev_cp_size] = (
                        sample_id_group[start_idx : start_idx + prev_cp_size]
                    )
                    break
                elif cp_size <= empty_size and i == -1:
                    # Reached the first group: double it and shift everything after it.
                    end_idx = -empty_size + cp_size if -empty_size + cp_size < 0 else None
                    sample_id_group[2 * cp_size : end_idx] = sample_id_group[cp_size:-empty_size]
                    sample_id_group[cp_size : 2 * cp_size] = sample_id_group[0:cp_size]
                    break
                prev_cp_size = cp_size
            return sample_id_group

        while len(sample_id_group[-1]) == 0:
            empty_before = sum(1 for x in sample_id_group if len(x) == 0)
            sample_id_group = fill_empty(sample_id_group)
            empty_after = sum(1 for x in sample_id_group if len(x) == 0)
            if empty_after >= empty_before:
                # No doubling branch applied; looping again would never terminate.
                raise AssertionError(
                    "align_sample_id_groups: could not fill empty ranks by expanding CP"
                )
        return sample_id_group

    attempts_since_split = 0
    while remainder > 0:
        if i < 0:
            # Wrapped past the head: give up if a whole pass produced no split.
            if attempts_since_split >= len(sample_id_groups):
                # Explicit raise (not ``assert False``) so this still fires under
                # ``python -O`` instead of looping forever.
                raise AssertionError(
                    'align_sample_id_groups: no tail microbatch has enough ids to split'
                )
            i = len(sample_id_groups) - 1
        group1, group2 = split_group(sample_id_groups[i])
        if group1 is not None and group2 is not None:
            sample_id_groups[i] = group1
            sample_id_groups.append(group2)
            remainder -= 1
            attempts_since_split = 0
        else:
            attempts_since_split += 1
        i -= 1
    return sample_id_groups
# =============================================================================
# Workload estimation helpers for dynamic CP scheduling
# =============================================================================
@lru_cache(maxsize=128)
def dcp_gpus_needed(
    seq_len: int, max_seq_len_per_rank: int, min_cp_size: int = 1, max_cp_size: Optional[int] = None
) -> int:
    """Number of GPUs needed, rounded up to the next power of 2, clamped to [min_cp_size, max_cp_size].

    Returns the smallest power of two ``p >= 1`` such that
    ``p * max_seq_len_per_rank >= seq_len``, raised to at least ``min_cp_size``
    and then, if ``max_cp_size`` is not None, lowered to at most ``max_cp_size``.
    The computation uses exact integer arithmetic, so a zero-length sequence
    yields 1 instead of a ``math.log2`` domain error.

    Raises:
        ValueError: If ``seq_len`` is negative.
    """
    if seq_len < 0:
        raise ValueError(f"seq_len must be non-negative, got {seq_len}")
    # Exact ceiling division; even an empty sequence occupies one rank.
    ranks_needed = max(1, int(-(-seq_len // max_seq_len_per_rank)))
    # Smallest power of two >= ranks_needed.
    raw = 1 << (ranks_needed - 1).bit_length()
    clamped = max(min_cp_size, raw)
    if max_cp_size is not None:
        clamped = min(clamped, max_cp_size)
    return clamped
@lru_cache(maxsize=128)
def dcp_get_total_workload(
    seq_length: int,
    max_seq_len_per_rank: int,
    cp_size: Optional[int] = None,
    min_cp_size: int = 1,
    max_cp_size: Optional[int] = None,
) -> float:
    """Estimate workload of a sub-sample for scheduling balance.

    The estimate is the square of ``seq_length`` divided by the number of CP
    ranks. When ``cp_size`` is omitted, the rank count is derived from
    ``dcp_gpus_needed`` using the remaining arguments.
    """
    if cp_size is not None:
        ranks = cp_size
    else:
        ranks = dcp_gpus_needed(seq_length, max_seq_len_per_rank, min_cp_size, max_cp_size)
    squared_len = seq_length * seq_length
    return squared_len / ranks
def dcp_make_buckets_equal(
sample_seqlens: List[Tuple[int, int]],
compute_estimator: Callable,
max_seq_len_per_rank: int,
min_cp_size: int = 1,