@@ -592,3 +592,134 @@ def test_skip_train_iteration(self):
         train_iterations = range(1, 10)
         for i in train_iterations:
             self.assertTrue(f"iteration {i:8d}/" in cs.out)
+
+    def test_pretrain_shared_t5_mlm(self):
+        # all in one test
+        src_dir = self.src_dir
+        data_dir = f"{self.data_dir}/gpt2"
+        output_dir = self.get_auto_remove_tmp_dir()  # "./xxx", after=False)
+        logs_dir = f"{output_dir}/logs"
+        Path(logs_dir).mkdir(parents=True, exist_ok=True)
+
+        pp_size, tp_size, dp_size = get_3d_dimensions()
+        num_gpus = pp_size * tp_size * dp_size
+
+        # TODO @thomasw21 fix once t5 supports pipeline parallelism
+        dp_size *= pp_size
+        pp_size = 1
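+        # folding the pipeline dimension into data parallelism keeps
+        # num_gpus = pp_size * tp_size * dp_size unchanged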
+
+        n_samples = 200  # about 37 iterations
+        exit_interval = 20  # some samples in the first half and then some more in the 2nd half after resume
+        noise_density = 0.15
+        mean_noise_span_length = 3
+        encoder_seq_length = 512
+        decoder_seq_length = 114  # imposed by `noise_density=0.15` and `input_sequence_length = 512`
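+        # derivation, following T5's span-corruption length arithmetic: a raw
+        # span of 568 tokens yields round(568 * 0.15) = 85 noise tokens in
+        # round(85 / 3) = 28 spans, i.e. encoder: (568 - 85) + 28 sentinels
+        # + 1 EOS = 512, decoder: 85 + 28 sentinels + 1 EOS = 114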
+
+        args = f"""
+            --tensor-model-parallel-size {tp_size}
+            --pipeline-model-parallel-size {pp_size}
+            --distributed-backend nccl
+
+            --num-layers 2
+            --hidden-size 64
+            --num-attention-heads 2
+            --decoder-seq-length {decoder_seq_length}
+            --encoder-seq-length {encoder_seq_length}
+            --max-position-embeddings 1024
+            --micro-batch-size 1
+            --rampup-batch-size 2 2 {n_samples}
+            --global-batch-size 16
+            --train-samples {n_samples}
+
+            --optimizer adam
+            --adam-beta1 0.9
+            --adam-beta2 0.95
+            --adam-eps 1e-8
+            --lr 1e-4
+            --lr-warmup-samples 5
+            --clip-grad 1.0
+            --weight-decay 1e-1
+            --fp16
+
+            --log-interval 5
+            --save-interval 10
+            --eval-interval 10
+            --eval-iters 5
+            --checkpoint-activations
+            --exit-interval {exit_interval}
+
+            --merge-file {data_dir}/gpt2-tiny-merges.txt
+            --vocab-file {data_dir}/gpt2-tiny-vocab.json
+            --log-path {logs_dir}
+            --save {output_dir}/checkpoints
+            --load {output_dir}/checkpoints
+            --data-path {data_dir}/meg-gpt2-openwebtext_text_document
+            --noise-density {noise_density}
+            --mean-noise-span-length {mean_noise_span_length}
+            --tensorboard-dir {output_dir}/tensorboard
+            --tensorboard-queue-size 5
+            --log-timers-to-tensorboard
+            --log-batch-size-to-tensorboard
+            --log-validation-ppl-to-tensorboard
+
+            --log-level debug
+        """.split()
+
+        ds_args = f"""
+            --deepspeed
+            --deepspeed_config {self.test_file_dir_str}/ds_config.json
+            --zero-stage 1
+            --deepspeed-activation-checkpointing
+        """.split()
+
+        script = [f"{src_dir}/pretrain_shared_t5_with_mlm.py"]
+        launcher = get_launcher(num_gpus)
+
+        cmd = launcher + script + args + ds_args
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] + cmd)); die
+
+        # 1. test training from scratch (no checkpoint)
+        with CaptureStdout() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+
+        # test deepspeed is running
+        self.assertIn("DeepSpeed info", cs.out)
+
+        # test reports
+        self.assertIn("consumed samples", cs.out)
+
+        # test there should be no checkpoint this round
+        self.assertIn(f"Unable to find latest file at {output_dir}/checkpoints/latest", cs.out)
+
+        # test checkpoint saving
+        self.assertIn("successfully saved checkpoint at iteration", cs.out)
+
+        # test tensorboard
+        tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
+        self.assertEqual(len(tensorboard_files), 1, "tensorboard files")
+
+        # 2. test training from checkpoint: resume
+        # now do it again, this time resuming from the checkpoint
+        with CaptureStdout() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+
+        # test checkpoint loading
+        self.assertIn(f"successfully loaded checkpoint from {output_dir}/checkpoints", cs.out)
+
+        # test reports
+        self.assertIn("consumed samples", cs.out)
+
+        # test checkpoint saving
+        self.assertIn("successfully saved checkpoint at iteration", cs.out)
+
+        # test tensorboard (1 file from the first run, plus 1 now)
+        tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
+        self.assertEqual(len(tensorboard_files), 2, "tensorboard files")
+
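For reference, a minimal sketch (not part of the commit) of how the encoder/decoder length pair above can be reproduced; it mirrors the length search in T5's span-corruption preprocessing (cf. `random_spans_helper` in the T5 codebase), and the function name here is illustrative:

def t5_span_corruption_lengths(inputs_length, noise_density, mean_noise_span_length):
    # lengths after corrupting a raw span of `tokens_length` tokens:
    # inputs keep the non-noise tokens, targets keep the noise tokens,
    # and both get one sentinel per noise span plus one EOS
    def lengths(tokens_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        return (tokens_length - num_noise_tokens + num_noise_spans + 1,
                num_noise_tokens + num_noise_spans + 1)

    # find the largest raw length whose corrupted inputs still fit the encoder
    tokens_length = inputs_length - 1
    while lengths(tokens_length + 1)[0] <= inputs_length:
        tokens_length += 1
    return tokens_length, lengths(tokens_length)[1]

print(t5_span_corruption_lengths(512, 0.15, 3))  # (568, 114): decoder_seq_length = 114

This is why `decoder_seq_length = 114` in the test is described as imposed by `noise_density=0.15` and the 512-token encoder input.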