@@ -64,10 +64,7 @@ DEFINE_int32(
6464 " When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices." );
6565DEFINE_uint32 (tensor_parallel, 1 , " Tensor Parallel world size" );
6666DEFINE_bool (sequence_parallel, false , " Whether to enable Sequence Parallel" );
67- DEFINE_uint32 (
68- pipeline_parallel, 1 ,
69- " Pipeline Parallel world size, will always use device=cuda and use all cuda visible devices when set to true" );
70- DEFINE_uint32 (num_microbatches, 4 , " the num of microbatches in pipeline parallelism" );
67+ DEFINE_uint32 (pipeline_parallel, 1 , " Pipeline Parallel world size, specified the number of PP stages." );
7168
7269// precision
7370DEFINE_string (dtype, " float32" , " precision used in training (float32/bfloat16)" );
@@ -148,14 +145,16 @@ void Train(const nn::parallel::Rank &rank) {
             pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(
                 GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
             pp_rank = pp_pg->GetGroupRank(rank.thread_rank());
+
+            nn::parallel::pp_rank = pp_rank;
         }
     } else {
         device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
                                             : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0);
     }
 
     // calculate gradient accumulation from the desired total batch size and the current run configuration
-    const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * (ddp_world_size * pp_world_size);
+    const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size;
     CHECK_EQ(FLAGS_total_batch_size % tokens_per_fwdbwd, 0);
     const auto grad_accum_steps = FLAGS_total_batch_size / tokens_per_fwdbwd;
     LOG(INFO) << "total desired batch size: " << FLAGS_total_batch_size
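
For a concrete sense of the new accumulation math, here is a minimal standalone sketch (plain C++ with made-up flag values; only the formula mirrors the `+` line above). With pipeline parallelism removed from the token count, only the data-parallel ranks multiply the per-step token budget:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>

int main() {
    // Illustrative numbers only; in the example binary these come from gflags.
    const int64_t total_batch_size = 524288; // desired tokens per optimizer step
    const int64_t batch_size = 8;            // sequences per micro-batch
    const int64_t sequence_length = 1024;    // tokens per sequence
    const int64_t ddp_world_size = 4;        // data-parallel replicas

    // Same formula as the '+' line above: pp_world_size no longer appears.
    const int64_t tokens_per_fwdbwd = batch_size * sequence_length * ddp_world_size; // 32768
    assert(total_batch_size % tokens_per_fwdbwd == 0);
    const int64_t grad_accum_steps = total_batch_size / tokens_per_fwdbwd;           // 16

    std::cout << tokens_per_fwdbwd << " tokens per fwd/bwd, "
              << grad_accum_steps << " grad-accum steps\n";
}
```
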
@@ -197,16 +196,10 @@ void Train(const nn::parallel::Rank &rank) {
         model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
     }
 
-    std::unique_ptr<DataLoader> train_loader;
-    if (pp_world_size > 1) {
-        train_loader = std::make_unique<DataLoader>(
-            std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
-            FLAGS_batch_size * pp_world_size);
-    } else {
-        train_loader = std::make_unique<DistributedDataLoader>(
-            std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_batch_size,
-            ddp_rank, ddp_world_size);
-    }
+    auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
+    DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
+                                       pp_world_size > 1 ? FLAGS_batch_size * num_micro_batches : FLAGS_batch_size,
+                                       ddp_rank, ddp_world_size);
 
     std::optional<DistributedDataLoader> val_loader = std::nullopt;
     if (!FLAGS_input_val_bin.empty()) {
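
With the same illustrative numbers, `num_micro_batches` works out to the same value as `grad_accum_steps` above (it is the same expression), and the loader's per-iteration batch grows accordingly when pipeline parallelism is on. A small sketch of just that arithmetic, not of the real DistributedDataLoader:

```cpp
#include <cstdint>
#include <iostream>

int main() {
    // Same made-up values as the previous sketch, plus a 2-stage pipeline.
    const int64_t total_batch_size = 524288, batch_size = 8, sequence_length = 1024;
    const int64_t ddp_world_size = 4, pp_world_size = 2;

    // Mirrors the '+' line above; identical to grad_accum_steps in the previous sketch.
    const int64_t num_micro_batches =
        total_batch_size / (batch_size * sequence_length * ddp_world_size); // 16

    // With PP, one loader iteration fetches all micro-batches for the step at once;
    // without PP it fetches a single micro-batch.
    const int64_t loader_batch = pp_world_size > 1 ? batch_size * num_micro_batches : batch_size;
    std::cout << "num_micro_batches=" << num_micro_batches
              << ", loader batch=" << loader_batch << '\n'; // 16, 128
}
```
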
@@ -225,13 +218,9 @@ void Train(const nn::parallel::Rank &rank) {
     }
 
     // TODO(dcj): support more complex optimizer later
-    auto lr = FLAGS_learning_rate;
-    auto optimizer_factory = [lr](const std::vector<std::shared_ptr<Tensor>> &params) {
-        return std::make_shared<optimizers::SGD>(params, lr);
-    };
-    auto optimizer = optimizer_factory(model->Parameters());
+    auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
 
-    auto train_iter = train_loader->begin();
+    auto train_iter = train_loader.begin();
     std::shared_ptr<nn::Module> loss_fn
         = (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
                                     std::make_shared<VocabParallelCrossEntropyLoss>(model_config.original_vocab_size))
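
The optimizer change swaps a factory lambda for a single concrete SGD instance; where the pipeline wrapper later needs a `std::shared_ptr`, the instance is copied into one (see the next hunk). A simplified, self-contained sketch of the two patterns, using stand-in types rather than the real infini_train classes:

```cpp
#include <memory>
#include <vector>

// Stand-ins for the real Tensor and optimizers::SGD types (illustration only).
struct Tensor {};
struct SGD {
    SGD(const std::vector<std::shared_ptr<Tensor>> &params, float lr) {}
};

int main() {
    std::vector<std::shared_ptr<Tensor>> params;
    const float lr = 1e-4f;

    // Before: a factory so each consumer could construct its own optimizer.
    auto optimizer_factory = [lr](const std::vector<std::shared_ptr<Tensor>> &p) {
        return std::make_shared<SGD>(p, lr);
    };
    auto from_factory = optimizer_factory(params);

    // After: one optimizer object, copied into a shared_ptr only where required.
    SGD optimizer(params, lr);
    auto shared_for_pp = std::make_shared<SGD>(optimizer); // copy, as in the pipeline hunk below
}
```
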
@@ -240,14 +229,10 @@ void Train(const nn::parallel::Rank &rank) {
     LOG(INFO) << "Rank " << rank.thread_rank() << ": start training";
 
     if (pp_world_size > 1) {
-        CHECK_EQ((FLAGS_batch_size * pp_world_size) % FLAGS_num_microbatches, 0)
-            << "FLAGS_batch_size (" << (FLAGS_batch_size * pp_world_size)
-            << ") must be divisible by FLAGS_num_microbatches (" << FLAGS_num_microbatches << ")";
-        auto shapes = std::vector<std::vector<int64_t>>{{(FLAGS_batch_size * pp_world_size) / FLAGS_num_microbatches,
-                                                         FLAGS_sequence_length, model->GetConfig()["n_embd"]}};
-
-        model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, FLAGS_num_microbatches, shapes,
-                                                                 pp_rank, optimizer_factory);
+        auto shapes = std::vector<std::vector<int64_t>>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}};
+
+        model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
+                                                                 pp_rank, std::make_shared<optimizers::SGD>(optimizer));
     }
 
     LOG(INFO) << "start training";
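
The `shapes` argument now describes the activation tensor exchanged between stages for a single micro-batch: `batch_size` sequences of `sequence_length` tokens with `n_embd` features each, rather than a slice of a `pp_world_size`-scaled batch. Purely for illustration, with made-up dimensions:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // Made-up dimensions; in the example these come from flags and the model config.
    const int64_t batch_size = 8, sequence_length = 1024, n_embd = 768;

    // One entry per inter-stage tensor: (batch_size, sequence_length, n_embd) per micro-batch.
    const std::vector<std::vector<int64_t>> shapes = {{batch_size, sequence_length, n_embd}};
    std::cout << shapes[0][0] << " x " << shapes[0][1] << " x " << shapes[0][2] << '\n';
}
```
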
@@ -274,81 +259,80 @@ void Train(const nn::parallel::Rank &rank) {
             break;
         }
 
-        // model->Train();
-        if (pp_world_size == 1) {
-            optimizer->ZeroGrad();
-        }
-        // if we are trying to overfit a single batch, we reset the loader here
-        if (FLAGS_overfit_single_batch) {
-            // train_loader.Reset();
-        }
-        float lossf = 0.0f;
 #ifdef PROFILE_MODE
         Profiler::Instance().SetTag("Step_" + std::to_string(step));
 #endif
-        for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) {
-            // enable autocast for the current step
-            infini_train::AutocastGuard autocast_guard(device->Type(), dtype);
 
-            // (bs, seq_len), (bs, seq_len)
+        float lossf = 0.0f;
+        // model->Train();
+        if (pp_world_size == 1) {
+            optimizer.ZeroGrad();
+
+            // if we are trying to overfit a single batch, we reset the loader here
+            if (FLAGS_overfit_single_batch) {
+                // train_loader.Reset();
+            }
+
+            for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) {
+                // enable autocast for the current step
+                infini_train::AutocastGuard autocast_guard(device->Type(), dtype);
+
+                // (bs, seq_len), (bs, seq_len)
+                auto [x, y] = *train_iter;
+                // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
+                // TODO(dcj): support dataloader.reset() later
+                ++train_iter;
+                x = std::make_shared<Tensor>(x->To(device));
+                y = std::make_shared<Tensor>(y->To(device));
+
+                LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward";
+                // (bs, seq_len, vocab_size)
+                auto logits = model->Forward({x, y})[0];
+                LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward";
+                auto loss = loss_fn->Forward({logits, y})[0];
+                loss = loss / grad_accum_steps;
+
+                // disable autocast for the current step (backward is not under autocast)
+                autocast_guard.Disable();
+
+                LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
+                if (ddp_world_size > 1) {
+                    function::AllReduce(loss, function::ReduceOpType::kAvg);
+                }
+                auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
+                lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
+                LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward";
+                loss->Backward();
+                LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward";
+            }
+
+            optimizer.Step();
+        } else {
             auto [x, y] = *train_iter;
             // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
             // TODO(dcj): support dataloader.reset() later
             ++train_iter;
             x = std::make_shared<Tensor>(x->To(device));
             y = std::make_shared<Tensor>(y->To(device));
 
-            if (pp_world_size > 1) {
-                lossf = model->TrainStep({x}, {y}, loss_fn);
-
-                auto loss_tensor = std::make_shared<Tensor>(std::vector<int64_t>{}, DataType::kFLOAT32);
-                static_cast<float *>(loss_tensor->DataPtr())[0] = lossf;
-                auto loss_device_ptr = std::make_shared<Tensor>(loss_tensor->To(device));
-                function::AllReduce(loss_device_ptr, function::ReduceOpType::kMax);
-                auto loss_copy = loss_device_ptr->To(DeviceManager::Instance()->GetDefaultDevice());
-                lossf = static_cast<const float *>(loss_copy.DataPtr())[0];
-                continue;
-            }
-
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward";
-            // (bs, seq_len, vocab_size)
-            auto logits = model->Forward({x, y})[0];
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward";
-            auto loss = loss_fn->Forward({logits, y})[0];
-            loss = loss / grad_accum_steps;
-
-            // disable autocast for the current step (backward is not under autocast)
-            autocast_guard.Disable();
-
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
-            if (ddp_world_size > 1) {
-                function::AllReduce(loss, function::ReduceOpType::kAvg);
-            }
-            auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
-            lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward";
-            loss->Backward();
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward";
-        }
-
-        if (pp_world_size == 1) {
-            optimizer->Step();
+            lossf = model->TrainStep({x}, {y}, loss_fn);
         }
         const auto iter_end = std::chrono::high_resolution_clock::now();
         const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
         const double tps = FLAGS_total_batch_size / (duration_us / 1e6);
 
-        if (rank.IsMainRank()) {
+        if (rank.thread_rank() == pp_world_size - 1) {
             LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
                                       "DP={}, TP={}, SP={}, PP={})",
                                       step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
                                       tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size);
 
             if ((step + 1) % FLAGS_freq_generate_txt == 0) {
-                if (!tokenizer) {
-                    continue;
+                if (tokenizer) {
+                    // FIXME(jym): to support PP
+                    CHECK_EQ(pp_world_size, 1);
+                    tokenizer->GenerateText(*model, FLAGS_batch_size, FLAGS_sequence_length, FLAGS_text_length, device);
                 }
-                tokenizer->GenerateText(*model, FLAGS_batch_size, FLAGS_sequence_length, FLAGS_text_length, device);
             }
         }
     }
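
Two details of the restructured loop are easy to miss. First, the logging condition changed from `rank.IsMainRank()` to `rank.thread_rank() == pp_world_size - 1`, presumably because under pipeline parallelism the loss ends up on the last stage; when PP is off this reduces to rank 0, so DP-only runs log as before. Second, in the non-pipeline branch each micro-batch loss is divided by `grad_accum_steps` before accumulation, so the reported `lossf` is the mean over the whole effective batch rather than a sum. A tiny sketch of that arithmetic with made-up losses:

```cpp
#include <iostream>
#include <vector>

int main() {
    // Made-up per-micro-batch losses.
    const std::vector<float> micro_losses = {2.0f, 1.5f, 1.0f, 0.5f};
    const int grad_accum_steps = static_cast<int>(micro_losses.size());

    // Dividing each contribution by grad_accum_steps (as in the loop above)
    // makes the accumulated value the mean over the effective batch.
    float lossf = 0.0f;
    for (float loss : micro_losses) {
        lossf += loss / grad_accum_steps;
    }
    std::cout << "lossf = " << lossf << '\n'; // 1.25
}
```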