@@ -64,6 +64,7 @@ DEFINE_int32(
 DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
 DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
 DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specifies the number of PP stages.");
+DEFINE_uint32(virtual_pipeline_parallel, 1, "Number of chunks per PP stage.");
 
 // precision
 DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
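For readers unfamiliar with virtual (interleaved) pipeline parallelism, the sketch below shows one way the new chunk count can interact with a model's layer count. `LayersPerChunk` and its divisibility check are illustrative only, not part of this PR.

```cpp
// Hypothetical helper (not in this PR): with virtual pipeline parallelism, each of the
// pipeline_parallel stages owns virtual_pipeline_parallel model chunks, so the layer
// count must split evenly across pipeline_parallel * virtual_pipeline_parallel chunks.
#include <cstdint>
#include <stdexcept>

uint32_t LayersPerChunk(uint32_t n_layer, uint32_t pipeline_parallel, uint32_t virtual_pipeline_parallel) {
    const uint32_t total_chunks = pipeline_parallel * virtual_pipeline_parallel;
    if (total_chunks == 0 || n_layer % total_chunks != 0) {
        throw std::invalid_argument("n_layer must be divisible by pipeline_parallel * virtual_pipeline_parallel");
    }
    return n_layer / total_chunks;  // e.g. 12 layers, pp=2, vpp=2 -> 3 layers per chunk
}
```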
@@ -187,15 +188,35 @@ void Train(const nn::parallel::Rank &rank) {
         LOG(FATAL) << "Rank " << rank.GlobalRank() << ": Datatype " << FLAGS_dtype << " not supported.";
     }
 
-    // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
-    // before wrapping the model with DistributedDataParallel (DDP).
-    // Otherwise, DDP's gradient hooks may be lost because new parameter tensors
-    // are created during the conversion.
-    if (ddp_world_size > 1) {
+    auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
+
+    // TODO(dcj): support more complex optimizer later
+    auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
+
+    if (pp_world_size > 1) {
+        // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
+        // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size.
+        auto shapes = std::vector<std::vector<int64_t>>{
+            {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};
+
+        model = std::make_shared<nn::parallel::PipelineParallel>(
+            model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared<optimizers::SGD>(optimizer),
+            rank.thread_rank(), std::dynamic_pointer_cast<GPT2>(model)->GetChunkSize());
+        if (ddp_world_size > 1) {
+            auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
+            for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
+                (*mutable_chunks)[chunk_id]
+                    = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank.thread_rank());
+            }
+        }
+    } else if (ddp_world_size > 1) {
+        // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
+        // before wrapping the model with DistributedDataParallel (DDP).
+        // Otherwise, DDP's gradient hooks may be lost because new parameter tensors
+        // are created during the conversion.
         model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
     }
 
-    auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
     DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
                                        pp_world_size > 1 ? FLAGS_batch_size * num_micro_batches : FLAGS_batch_size,
                                        ddp_rank, ddp_world_size);
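The micro-batch count computed above treats `FLAGS_total_batch_size` as a token budget per optimizer step, and each pipeline chunk is then wrapped in `DistributedDataParallel` individually, presumably so gradient hooks attach to the parameters that actually live in each chunk. A standalone sketch of the arithmetic, with illustrative values not taken from this PR:

```cpp
// Standalone sketch of the num_micro_batches arithmetic (values are illustrative).
#include <cstdint>
#include <iostream>

int main() {
    const int64_t total_batch_size = 524288;  // tokens consumed per optimizer step (hypothetical)
    const int64_t batch_size = 8;             // sequences per micro-batch on one rank
    const int64_t sequence_length = 1024;     // tokens per sequence
    const int64_t ddp_world_size = 4;         // data-parallel replicas

    // Tokens per micro-batch across all replicas: 8 * 1024 * 4 = 32768.
    const int64_t num_micro_batches = total_batch_size / (batch_size * sequence_length * ddp_world_size);
    std::cout << "num_micro_batches = " << num_micro_batches << '\n';  // 524288 / 32768 = 16
    return 0;
}
```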
@@ -216,9 +237,6 @@ void Train(const nn::parallel::Rank &rank) {
         tokenizer = std::make_unique<Tokenizer>(FLAGS_tokenizer_bin);
     }
 
-    // TODO(dcj): support more complex optimizer later
-    auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
-
     auto train_iter = train_loader.begin();
     std::shared_ptr<nn::Module> loss_fn
         = (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
@@ -227,17 +245,6 @@ void Train(const nn::parallel::Rank &rank) {
     loss_fn->To(device);
     LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";
 
-    if (pp_world_size > 1) {
-        // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
-        // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size.
-        auto shapes = std::vector<std::vector<int64_t>>{
-            {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};
-
-        model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
-                                                                 pp_rank, std::make_shared<optimizers::SGD>(optimizer),
-                                                                 rank.thread_rank());
-    }
-
     LOG(INFO) << "start training";
 
     for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
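The removed block above is only the old construction site; the shape math itself is unchanged by this PR. As a reminder of what the stage-boundary shape looks like when sequence parallelism is enabled, here is a small sketch (the function name is illustrative):

```cpp
// Sketch of the activation shape passed between pipeline stages; with sequence
// parallelism each rank holds only 1/sp_world_size of the sequence dimension.
#include <cstdint>
#include <vector>

std::vector<int64_t> BoundaryShape(int64_t batch_size, int64_t sequence_length,
                                   int64_t sp_world_size, int64_t n_embd) {
    return {batch_size, sequence_length / sp_world_size, n_embd};
}
```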
@@ -293,6 +300,7 @@ void Train(const nn::parallel::Rank &rank) {
             auto logits = model->Forward({x, y})[0];
             LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish model forward, start loss forward";
             auto loss = loss_fn->Forward({logits, y})[0];
+            // FIXME(jym): verify gradient accumulation precision
             loss = loss / grad_accum_steps;
 
             // disable autocast for the current step (backward is not under autocast)
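On the FIXME: dividing each micro-batch loss by `grad_accum_steps` before backward is the usual way to make accumulated gradients approximate the full-batch mean; the precision concern is that the division happens per micro-step, possibly in reduced precision. A generic sketch of the identity being relied on (not the project's API):

```cpp
// Generic gradient-accumulation sketch: summing (loss_i / N) over N micro-steps
// equals the mean of the N losses, i.e. what a single large batch would produce.
#include <cstddef>
#include <vector>

float AccumulatedMeanLoss(const std::vector<float> &micro_losses) {
    const std::size_t n = micro_losses.size();  // plays the role of grad_accum_steps
    float scaled_sum = 0.0f;
    for (float loss : micro_losses) {
        scaled_sum += loss / static_cast<float>(n);  // mirrors `loss = loss / grad_accum_steps`
    }
    return scaled_sum;
}
```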
@@ -356,7 +364,7 @@ int main(int argc, char *argv[]) {
     google::InitGoogleLogging(argv[0]);
 
     nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
-                                     FLAGS_pipeline_parallel);
+                                     FLAGS_pipeline_parallel, FLAGS_virtual_pipeline_parallel);
 
     LOG(INFO) << nn::parallel::global::ProcessGroupOverview();
 
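Note that the virtual pipeline degree does not consume extra ranks; it only subdivides each PP stage into more chunks. A generic sketch of the usual rank bookkeeping (a common convention, not necessarily what `InitAllEnv` does internally):

```cpp
// Generic convention: the data-parallel degree is what remains of the world size after
// tensor and pipeline parallelism; virtual PP chunks do not change the rank count.
#include <cstdint>
#include <stdexcept>

uint32_t DataParallelWorldSize(uint32_t world_size, uint32_t tensor_parallel, uint32_t pipeline_parallel) {
    const uint32_t model_parallel = tensor_parallel * pipeline_parallel;
    if (model_parallel == 0 || world_size % model_parallel != 0) {
        throw std::invalid_argument("world_size must be divisible by tensor_parallel * pipeline_parallel");
    }
    return world_size / model_parallel;
}
```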