Skip to content

Commit 3f9564f

Browse files
committed
need reset
Signed-off-by: tailaim <tailaim@nvidia.com>
1 parent 669e0f3 commit 3f9564f

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

megatron/core/datasets/data_schedule.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1641,13 +1641,22 @@ def fill_empty(sample_id_group):
16411641
sample_id_group = fill_empty(sample_id_group)
16421642
return sample_id_group
16431643

1644+
attempts_since_split = 0
16441645
while remainder > 0:
1645-
assert i >= 0, f'align_sample_id_groups: no tail microbatch has enough ids to split'
1646+
if i < 0:
1647+
if attempts_since_split >= len(sample_id_groups):
1648+
assert (
1649+
False
1650+
), f'align_sample_id_groups: no tail microbatch has enough ids to split'
1651+
i = len(sample_id_groups) - 1
16461652
group1, group2 = split_group(sample_id_groups[i])
16471653
if group1 is not None and group2 is not None:
16481654
sample_id_groups[i] = group1
16491655
sample_id_groups.append(group2)
16501656
remainder -= 1
1657+
attempts_since_split = 0
1658+
else:
1659+
attempts_since_split += 1
16511660
i -= 1
16521661

16531662
return sample_id_groups

tests/unit_tests/context_parallel/test_packing_and_hybrid_cp.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -523,10 +523,7 @@ def set_tp_pp_vpp(tp, pp, cp, vpp=None, destroy_first=True):
523523
)
524524
model = model if isinstance(model, list) else [model]
525525

526-
#debugmtl
527-
print("daddy is here!!!",flush=True)
528526
data_iterator = get_data_iterator(args)
529-
print("daddy is out!!!",flush=True)
530527

531528
forward_backward_func = get_forward_backward_func()
532529
losses_reduced = forward_backward_func(
@@ -976,6 +973,9 @@ def _create_single_sample(seq_len):
976973
# Call the function under test
977974
(new_data_iterator, num_micro_batches, num_total_tokens_this_global_batch, sequence_square_sum_this_global_batch) = wrap_dataloader(data_iterator, config, scheduler_type)
978975

976+
#debugmtl
977+
print(f"rank:{torch.distributed.get_rank()}, exit wrap_dataloader",flush=True)
978+
979979
# check the result
980980
assert type(num_micro_batches) is int
981981
assert type(num_total_tokens_this_global_batch) is float
@@ -1025,5 +1025,8 @@ def _check_batch(batch_all, batch_keys):
10251025
assert new_data_iterator is None
10261026

10271027
finally:
1028+
#debugmtl
1029+
if torch.distributed.get_rank() == 0:
1030+
print(f"rank:0, exit test_wrap_dataloader successfully with tp:{tp}, pp:{pp}, cp:{cp}, vpp:{vpp}, scheduler_type:{scheduler_type}",flush=True)
10281031
Utils.destroy_model_parallel()
10291032
unset_global_variables()

0 commit comments

Comments
 (0)