Skip to content

Commit ffe8f94

Browse files
committed
add test_wrap_dataloader UT
Signed-off-by: xiaoyao0115 <1804647152@qq.com>
1 parent 86581cd commit ffe8f94

File tree

8 files changed

+325
-129
lines changed

8 files changed

+325
-129
lines changed

megatron/core/datasets/data_schedule.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,8 @@ def _build_packed_microbatches(
437437
# When VPP is enabled, align num_micro_batches to this multiple.
438438
(
439439
None
440-
if config.virtual_pipeline_model_parallel_size is None
440+
if (config.virtual_pipeline_model_parallel_size is None or
441+
config.virtual_pipeline_model_parallel_size == 1)
441442
else config.microbatch_group_size_per_vp_stage
442443
),
443444
config.hybrid_context_parallel,
@@ -1009,7 +1010,14 @@ def get_groups_and_subsamples(self, sample_id_seqlens):
10091010
single_microbatch = []
10101011

10111012
for i in range(len(sample_id_seqlens)):
1012-
single_microbatch = [i]
1013+
if sum_seqlen + sample_id_seqlens[i][1] <= self.max_seq_len_all_ranks:
1014+
single_microbatch.append(i)
1015+
sum_seqlen += sample_id_seqlens[i][1]
1016+
else:
1017+
packed_id_groups.append(single_microbatch)
1018+
single_microbatch = [i]
1019+
sum_seqlen = sample_id_seqlens[i][1]
1020+
if len(single_microbatch) > 0:
10131021
packed_id_groups.append(single_microbatch)
10141022

10151023
# we want the number of packed sequences to be multiple of dp_size
@@ -1100,6 +1108,8 @@ def check_require_sample_keys(self, batch: List[Dict]):
11001108
# we only fetch it once, rather than iterating num_micro_batches times.
11011109
for key in required_keys:
11021110
if key not in batch[0]:
1111+
#debugmtl
1112+
print(f"key {key} not in batch[0]: {batch[0]}")
11031113
return False
11041114
return True
11051115

@@ -1631,13 +1641,22 @@ def fill_empty(sample_id_group):
16311641
sample_id_group = fill_empty(sample_id_group)
16321642
return sample_id_group
16331643

1644+
attempts_since_split = 0
16341645
while remainder > 0:
1635-
assert i >= 0, f'align_sample_id_groups: no tail microbatch has enough ids to split'
1646+
if i < 0:
1647+
if attempts_since_split >= len(sample_id_groups):
1648+
assert (
1649+
False
1650+
), f'align_sample_id_groups: no tail microbatch has enough ids to split'
1651+
i = len(sample_id_groups) - 1
16361652
group1, group2 = split_group(sample_id_groups[i])
16371653
if group1 is not None and group2 is not None:
16381654
sample_id_groups[i] = group1
16391655
sample_id_groups.append(group2)
16401656
remainder -= 1
1657+
attempts_since_split = 0
1658+
else:
1659+
attempts_since_split += 1
16411660
i -= 1
16421661

16431662
return sample_id_groups
@@ -1704,16 +1723,18 @@ def _broadcast_to_tp_group(item):
17041723

17051724
# data_iterator should return a batch including the following keys.
17061725
batch_keys = [
1707-
'tokens',
1708-
'position_ids',
1709-
'labels',
1710-
'loss_mask',
17111726
'cu_seqlens',
17121727
'cu_seqlens_padded',
17131728
'max_seqlen',
17141729
]
17151730
if hybrid_context_parallel:
17161731
batch_keys.append('local_cp_size')
1732+
if is_first_stage:
1733+
batch_keys.append('tokens')
1734+
batch_keys.append('position_ids')
1735+
if is_last_stage:
1736+
batch_keys.append('labels')
1737+
batch_keys.append('loss_mask')
17171738

17181739
# Get a batch from data_iterator or create an empty batch.
17191740
if is_tp_rank_0:

megatron/core/extensions/transformer_engine.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1329,21 +1329,17 @@ def forward(
13291329
"""Forward."""
13301330
if packed_seq_params is not None:
13311331
# If Dynamic CP group is provided, update TE DPA CP group
1332-
if packed_seq_params.cp_group is not None:
1333-
self.cp_group = packed_seq_params.cp_group
1334-
super().set_context_parallel_group(
1335-
self.cp_group,
1336-
torch.distributed.get_process_group_ranks(self.cp_group),
1337-
TEDotProductAttention.cp_stream,
1338-
self.cp_comm_type,
1339-
)
1340-
# If cp_group is None but local_cp_size is provided,
1341-
# Indicates to turn off CP dynamically
1342-
elif packed_seq_params.local_cp_size is not None:
1343-
assert (
1344-
packed_seq_params.local_cp_size == 1
1345-
), "local_cp_size must be == 1 if provided without cp_group"
1346-
super().set_context_parallel_group(None, None, None, self.cp_comm_type)
1332+
if packed_seq_params.local_cp_size is not None:
1333+
if packed_seq_params.local_cp_size == 1:
1334+
super().set_context_parallel_group(None, None, None, self.cp_comm_type)
1335+
else:
1336+
self.cp_group = packed_seq_params.cp_group
1337+
super().set_context_parallel_group(
1338+
self.cp_group,
1339+
torch.distributed.get_process_group_ranks(self.cp_group),
1340+
TEDotProductAttention.cp_stream,
1341+
self.cp_comm_type,
1342+
)
13471343
self.kept_packed_seq_params.discard("cp_group")
13481344
self.kept_packed_seq_params.discard("local_cp_size")
13491345

megatron/core/model_parallel_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,3 +478,8 @@ def __post_init__(self):
478478
"SFT sequence packing requires Transformer Engine >= 2.9.0 "
479479
f"but got {get_te_version()} (TE < 2.9.0 may have convergence issues)."
480480
)
481+
if self.sequence_packing_scheduler == None:
482+
if self.hybrid_context_parallel:
483+
self.sequence_packing_scheduler = "default_hybrid_cp"
484+
else:
485+
self.sequence_packing_scheduler = "naive_sequence_packing"

megatron/core/pipeline_parallel/schedules.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,6 @@ def wrap_iterator_helper(
520520
):
521521
"""Warp data iterator for sequence packing if needed."""
522522
if config.sequence_packing:
523-
num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch = None, None
524523
scheduler_type_map = {
525524
'default_hybrid_cp': PackingScheduler.DEFAULT_HYBRID_CP,
526525
'empty_scheduler_with_packing': PackingScheduler.EMPTY_PACKING,
@@ -707,7 +706,7 @@ def forward_backward_no_pipelining(
707706
):
708707
create_cudagraphs()
709708

710-
if config.sequence_packing:
709+
if config.sequence_packing and not forward_only:
711710
forward_data_store.append(
712711
[num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch]
713712
)
@@ -2091,7 +2090,7 @@ def pp_post_backward(input_tensor_grad, vp_stage=None):
20912090
create_cudagraphs()
20922091
nvtx_range_pop(suffix="misc")
20932092

2094-
if config.sequence_packing:
2093+
if config.sequence_packing and not forward_only:
20952094
forward_data_store.append(
20962095
[num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch]
20972096
)
@@ -2489,7 +2488,7 @@ def enable_grad_sync():
24892488
):
24902489
create_cudagraphs()
24912490

2492-
if config.sequence_packing:
2491+
if config.sequence_packing and not forward_only:
24932492
forward_data_store.append(
24942493
[num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch]
24952494
)

megatron/training/arguments.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -815,11 +815,6 @@ def validate_args(args, defaults={}):
815815
# TODO(tailaim): add support for other dispatcher types
816816
print(f"Setting moe_token_dispatcher_type to alltoall for sft sequence packing with pipeline parallelism")
817817
args.moe_token_dispatcher_type = "alltoall"
818-
if args.sequence_packing_scheduler is None:
819-
if args.hybrid_context_parallel:
820-
args.sequence_packing_scheduler = 'default_hybrid_cp'
821-
else:
822-
args.sequence_packing_scheduler = 'naive_sequence_packing'
823818
else:
824819
args.variable_seq_lengths = False
825820

@@ -983,6 +978,9 @@ def validate_args(args, defaults={}):
983978
assert args.context_parallel_size == 1, 'context parallel size must be 1 for hybrid context parallelism'
984979

985980
if args.sequence_packing:
981+
assert not args.create_attention_mask_in_dataloader, \
982+
'Sequence packing does not support create_attention_mask_in_dataloader. ' \
983+
'Please set --no-create-attention-mask-in-dataloader'
986984
# Validate that packed sequence buffer is large enough for single sequences
987985
if args.hybrid_context_parallel:
988986
# packed_buffer_size = hdp_size * max_seqlen_per_rank >= single_seq_max_len
@@ -2932,7 +2930,7 @@ def _add_distributed_args(parser):
29322930
'Requires --max-seqlen-per-dp-cp-rank to be set.')
29332931
group.add_argument('--min-hybrid-context-parallel-size', type=int, default=1,
29342932
help='Minimum size of the hybrid context parallel groups.')
2935-
group.add_argument('--sequence-packing-scheduler', type=str, default='default_hybrid_cp',
2933+
group.add_argument('--sequence-packing-scheduler', type=str, default=None,
29362934
choices=['default_hybrid_cp', 'empty_scheduler_with_packing', 'empty_scheduler_no_packing', 'naive_sequence_packing'],
29372935
help='Scheduler for sequence packing and hybrid context parallel. '
29382936
'naive_sequence_packing: default naive sequence packing scheduler(just THD, no Hybrid-CP, this '

megatron/training/training.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2839,10 +2839,6 @@ def evaluate(
28392839
decoder_seq_length=args.decoder_seq_length,
28402840
forward_only=True,
28412841
)
2842-
if args.sequence_packing:
2843-
# need to drop first two elements which are total_num_tokens and
2844-
# total_sequence_square_sum
2845-
loss_dicts = loss_dicts[2:]
28462842
ft_integration.on_eval_step_end()
28472843
config.timers = get_timers()
28482844

0 commit comments

Comments (0)