
Commit 48ed32f

Shared t5 tests
1 parent 555f1ba commit 48ed32f

File tree

4 files changed (+139, -5 lines)


megatron/arguments.py

Lines changed: 2 additions & 2 deletions
@@ -925,8 +925,8 @@ def __call__(self, parser, args, values, option_string=None):
                        'specific positions. This option tries to un-bias the loss by reweighting loss on specific '
                        'positions based on how frequently we train on that position.'
                        'This is mostly used for prefix_lm training')
-    group.add_argument("--noise_density", type=float, default=None, help="Span corruption noise density")
-    group.add_argument("--mean_noise_span_length", type=int, default=None, help="Span corruption mean noise span length")
+    group.add_argument("--noise-density", type=float, default=None, help="Span corruption noise density")
+    group.add_argument("--mean-noise-span-length", type=int, default=None, help="Span corruption mean noise span length")


     return parser
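
The switch to dashed option names is purely cosmetic on the Python side: argparse replaces dashes with underscores when it derives the attribute name, so downstream code continues to read `args.noise_density` and `args.mean_noise_span_length`. A minimal standalone sketch of that mapping (illustration only, not the Megatron parser itself):

    # Standalone illustration: argparse maps "--noise-density" to args.noise_density.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--noise-density", type=float, default=None,
                        help="Span corruption noise density")
    parser.add_argument("--mean-noise-span-length", type=int, default=None,
                        help="Span corruption mean noise span length")

    args = parser.parse_args(["--noise-density", "0.15", "--mean-noise-span-length", "3"])
    assert args.noise_density == 0.15
    assert args.mean_noise_span_length == 3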

megatron/data/mlm_dataset.py

Lines changed: 9 additions & 1 deletion
@@ -3,7 +3,7 @@
 import numpy as np
 import torch

-from megatron import print_rank_0, get_tokenizer
+from megatron import print_rank_0, get_tokenizer, get_args
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, get_split_by_range_
 from megatron.data.dataset_utils import get_train_valid_test_split_, get_indexed_dataset_
@@ -303,6 +303,14 @@ def __init__(
         assert len(self.sentinel_token_ids) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
         assert len(self.sentinel_token_ids) >= self.num_noise_spans, "Not enough sentinel tokens, please add more"

+        args = get_args()
+        if hasattr(args, "encoder_seq_length") and args.encoder_seq_length is not None:
+            # T5 style
+            assert self.inputs_length == args.encoder_seq_length
+            assert self.targets_length == args.decoder_seq_length
+        else:
+            assert self.inputs_length + self.targets_length == args.seq_length
+
     def __len__(self):
         return len(self.samples_mapping)
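
The `inputs_length`/`targets_length` pair checked here is not a free choice: in span corruption both lengths follow from the raw token count, `noise_density`, and `mean_noise_span_length`. A sketch of that arithmetic, assumed here to follow the reference T5 recipe rather than copied from mlm_dataset.py, shows why `--encoder-seq-length 512` with `noise_density=0.15` and `mean_noise_span_length=3` pins the decoder length used in the new test at 114:

    # Sketch of the T5 span-corruption length bookkeeping (assumption: this repo
    # follows the reference T5 recipe; function and variable names are illustrative).
    def corrupted_lengths(tokens_length, noise_density, mean_noise_span_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_noise_spans = max(int(round(num_noise_tokens / mean_noise_span_length)), 1)
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        inputs_length = num_nonnoise_tokens + num_noise_spans + 1   # kept tokens + one sentinel per span + EOS
        targets_length = num_noise_tokens + num_noise_spans + 1     # noise tokens + one sentinel per span + EOS
        return inputs_length, targets_length

    # Largest raw length whose corrupted encoder input still fits in 512 tokens.
    tokens_length = 512
    while corrupted_lengths(tokens_length + 1, 0.15, 3)[0] <= 512:
        tokens_length += 1

    print(tokens_length, corrupted_lengths(tokens_length, 0.15, 3))  # 568 (512, 114)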

pretrain_shared_t5.py renamed to pretrain_shared_t5_with_mlm.py

Lines changed: 2 additions & 2 deletions
@@ -105,7 +105,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
         data_impl=args.data_impl,
         splits_string=args.split,
         train_valid_test_num_samples=train_val_test_num_samples,
-        sequence_length=args.seq_length,
+        sequence_length=args.encoder_seq_length + args.decoder_seq_length,
         noise_density=args.noise_density,
         mean_noise_span_length=args.mean_noise_span_length,
         seed=args.seed,
@@ -137,7 +137,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
         splits=splits,
         data_impl=args.data_impl,
         train_valid_test_num_samples=train_val_test_num_samples,
-        seq_length=args.seq_length,
+        seq_length=args.encoder_seq_length + args.decoder_seq_length,
         noise_density=args.noise_density,
         mean_noise_span_length=args.mean_noise_span_length,
         seed=args.seed,
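
With the settings used in the new test, the combined length passed to the dataset builders is 512 + 114 = 626 tokens per sample, which satisfies both branches of the check added to mlm_dataset.py above. A tiny illustration of the accounting, using those test-specific values:

    # Illustration only, with the values from the test below.
    encoder_seq_length = 512   # matches inputs_length in the T5-style branch
    decoder_seq_length = 114   # matches targets_length in the T5-style branch
    sequence_length = encoder_seq_length + decoder_seq_length
    assert sequence_length == 626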

tests/test_training.py

Lines changed: 126 additions & 0 deletions
@@ -592,3 +592,129 @@ def test_skip_train_iteration(self):
         train_iterations = range(1,10)
         for i in train_iterations:
             self.assertTrue(f"iteration {i:8d}/" in cs.out)
+
+    def test_pretrain_shared_t5_mlm(self):
+        # all in one test
+        src_dir = self.src_dir
+        data_dir = f"{self.data_dir}/gpt2"
+        output_dir = self.get_auto_remove_tmp_dir()  # "./xxx", after=False)
+        logs_dir = f"{output_dir}/logs"
+        Path(logs_dir).mkdir(parents=True, exist_ok=True)
+
+        pp_size, tp_size, dp_size = get_3d_dimensions()
+        num_gpus = pp_size * tp_size * dp_size
+
+        # TODO @thomasw21 fix once t5 supports pipeline parallelism
+        dp_size *= pp_size
+        pp_size = 1
+
+        n_samples = 200  # about 37 iterations
+        exit_interval = 20  # some samples in the first half and then some more in the 2nd half after resume
+        noise_density = 0.15
+        mean_noise_span_length = 3
+        encoder_seq_length = 512
+        decoder_seq_length = 114  # imposed by `noise_density=0.15` and `input_sequence_length = 512`
+
+        args = f"""
+            --tensor-model-parallel-size {tp_size}
+            --pipeline-model-parallel-size {pp_size}
+            --distributed-backend nccl
+
+            --num-layers 2
+            --hidden-size 64
+            --num-attention-heads 2
+            --decoder-seq-length {decoder_seq_length}
+            --encoder-seq-length {encoder_seq_length}
+            --max-position-embeddings 1024
+            --micro-batch-size 1
+            --rampup-batch-size 2 2 {n_samples}
+            --global-batch-size 16
+            --train-samples {n_samples}
+
+            --optimizer adam
+            --adam-beta1 0.9
+            --adam-beta2 0.95
+            --adam-eps 1e-8
+            --lr 1e-4
+            --lr-warmup-samples 5
+            --clip-grad 1.0
+            --weight-decay 1e-1
+            --fp16
+
+            --log-interval 5
+            --save-interval 10
+            --eval-interval 10
+            --eval-iters 5
+            --checkpoint-activations
+            --exit-interval {exit_interval}
+
+            --merge-file {data_dir}/gpt2-tiny-merges.txt
+            --vocab-file {data_dir}/gpt2-tiny-vocab.json
+            --log-path {logs_dir}
+            --save {output_dir}/checkpoints
+            --load {output_dir}/checkpoints
+            --data-path {data_dir}/meg-gpt2-openwebtext_text_document
+            --noise-density {noise_density}
+            --mean-noise-span-length {mean_noise_span_length}
+            --tensorboard-dir {output_dir}/tensorboard
+            --tensorboard-queue-size 5
+            --log-timers-to-tensorboard
+            --log-batch-size-to-tensorboard
+            --log-validation-ppl-to-tensorboard
+
+            --log-level debug
+        """.split()
+
+        ds_args = f"""
+            --deepspeed
+            --deepspeed_config {self.test_file_dir_str}/ds_config.json
+            --zero-stage 1
+            --deepspeed-activation-checkpointing
+        """.split()
+
+        script = [f"{src_dir}/pretrain_shared_t5_with_mlm.py"]
+        launcher = get_launcher(num_gpus)
+
+        cmd = launcher + script + args + ds_args
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
+
+        # 1. test training from scratch (no checkpoint)
+        with CaptureStdout() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+
+        # test deepspeed is running
+        self.assertIn("DeepSpeed info", cs.out)
+
+        # test reports
+        self.assertIn("consumed samples", cs.out)
+
+        # test there should be no checkpoint this round
+        self.assertIn(f"Unable to find latest file at {output_dir}/checkpoints/latest", cs.out)
+
+        # test checkpoint saving
+        self.assertIn("successfully saved checkpoint at iteration", cs.out)
+
+        # test tensorboard
+        tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
+        self.assertEqual(len(tensorboard_files), 1, "tensorboard files")
+
+        # 2. test training from checkpoint: resume
+        # now do it again, this time resuming from the checkpoint
+        with CaptureStdout() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+
+        # test checkpoint loading
+        self.assertIn(f"successfully loaded checkpoint from {output_dir}/checkpoints", cs.out)
+
+        # test reports
+        self.assertIn("consumed samples", cs.out)
+
+        # test checkpoint saving
+        self.assertIn("successfully saved checkpoint at iteration", cs.out)
+
+        # test tensorboard (1 file from the first run, plus 1 now)
+        tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
+        self.assertEqual(len(tensorboard_files), 2, "tensorboard files")
