 #include "infini_train/include/nn/parallel/distributed_data_parallel.h"
 #include "infini_train/include/nn/parallel/global.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
+#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
 #include "infini_train/include/nn/parallel/rank.h"
 #include "infini_train/include/nn/parallel/reduce_op_type.h"
 #include "infini_train/include/nn/parallel/tensor_parallel.h"
@@ -63,6 +64,11 @@ DEFINE_int32(
6364 " When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices." );
6465DEFINE_uint32 (tensor_parallel, 1 , " Tensor Parallel world size" );
6566DEFINE_bool (sequence_parallel, false , " Whether to enable Sequence Parallel" );
67+ DEFINE_bool (
68+ pipeline_parallel, false ,
69+ " use pipeline parallelism or not, will always use device=cuda and use all cuda visible devices when set to true" );
70+ DEFINE_uint32 (num_microbatches, 4 , " the num of microbatches in pipeline parallelism" );
71+
6672// precision
6773DEFINE_string (dtype, " float32" , " precision used in training (float32/bfloat16)" );
6874
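The two new flags are tied together by the divisibility check added later in Train(): with pp_world_size pipeline stages, the loader yields a global batch of FLAGS_batch_size * pp_world_size samples, which the pipeline engine slices into FLAGS_num_microbatches micro-batches. A minimal sketch of that arithmetic (illustrative helper, not part of the patch; the example values are assumptions):

#include <cassert>
#include <cstdint>

// Mirrors the CHECK_EQ added in Train(): the global batch must split evenly
// into micro-batches before the pipeline wrapper is constructed.
int64_t MicroBatchSize(int64_t batch_size, int64_t pp_world_size, int64_t num_microbatches) {
    const int64_t global_batch = batch_size * pp_world_size;
    assert(global_batch % num_microbatches == 0);
    return global_batch / num_microbatches;
}
// Example (assumed values): batch_size=8, pp_world_size=4, num_microbatches=4 -> 8 samples per micro-batch.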
@@ -106,6 +112,7 @@ void Train(const nn::parallel::Rank &rank) {
     int ddp_world_size = global::GetDataParallelSize();
     int tp_world_size = global::GetTensorParallelSize();
     int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 0;
+    int pp_world_size = global::GetPipelineParallelSize();
 
     if (FLAGS_sequence_parallel) {
         CHECK_EQ(FLAGS_sequence_length % tp_world_size, 0)
@@ -114,9 +121,11 @@ void Train(const nn::parallel::Rank &rank) {
 
     int ddp_rank = 0;
     int tp_rank = 0;
+    int pp_rank = 0;
 
     const ProcessGroup *ddp_pg = nullptr;
     const ProcessGroup *tp_pg = nullptr;
+    const ProcessGroup *pp_pg = nullptr;
 
     if (rank.IsParallel()) {
         device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
@@ -134,6 +143,12 @@ void Train(const nn::parallel::Rank &rank) {
             // NOTE(zbl): Reserved for VocabParallelEmbedding
             nn::parallel::tp_rank = tp_rank;
         }
+
+        if (pp_world_size > 1) {
+            pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(
+                GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
+            pp_rank = pp_pg->GetGroupRank(rank.thread_rank());
+        }
     } else {
         device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
                                             : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0);
@@ -174,6 +189,10 @@ void Train(const nn::parallel::Rank &rank) {
         LOG(FATAL) << "Rank " << rank.thread_rank() << ": Datatype " << FLAGS_dtype << " not supported.";
     }
 
+    // TODO(jym): Temporary implementation before 3D parallelism
+    if (FLAGS_pipeline_parallel) {
+        ddp_world_size = 1;
+    }
     // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
     // before wrapping the model with DistributedDataParallel (DDP).
     // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
@@ -182,8 +201,17 @@ void Train(const nn::parallel::Rank &rank) {
         model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
     }
 
-    DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
-                                       FLAGS_batch_size, ddp_rank, ddp_world_size);
+    std::unique_ptr<DataLoader> train_loader;
+    if (FLAGS_pipeline_parallel) {
+        train_loader = std::make_unique<DataLoader>(
+            std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
+            FLAGS_batch_size * pp_world_size);
+    } else {
+        train_loader = std::make_unique<DistributedDataLoader>(
+            std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_batch_size,
+            ddp_rank, ddp_world_size);
+    }
+
     std::optional<DistributedDataLoader> val_loader = std::nullopt;
     if (!FLAGS_input_val_bin.empty()) {
         val_loader = DistributedDataLoader(
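The loader now differs by mode: the pipeline path builds one plain DataLoader over a global batch of FLAGS_batch_size * pp_world_size samples (data parallelism is forced to 1 above), while the non-pipeline path keeps the rank-sharded DistributedDataLoader. A sketch of the resulting per-batch sample counts (illustrative helper, not in the patch):

#include <cstdint>

// Samples drawn per loader batch, summed across data-parallel ranks, per the
// two branches above: the PP loader yields batch_size * pp_world_size at once,
// while each of the ddp_world_size ranks draws batch_size samples of its shard.
int64_t GlobalSamplesPerLoaderBatch(bool pipeline_parallel, int64_t batch_size, int64_t pp_world_size,
                                    int64_t ddp_world_size) {
    return pipeline_parallel ? batch_size * pp_world_size : batch_size * ddp_world_size;
}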
@@ -201,16 +229,33 @@ void Train(const nn::parallel::Rank &rank) {
     }
 
     // TODO(dcj): support more complex optimizer later
-    auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
+    auto lr = FLAGS_learning_rate;
+    auto optimizer_factory = [lr](const std::vector<std::shared_ptr<Tensor>> &params) {
+        return std::make_shared<optimizers::SGD>(params, lr);
+    };
+    auto optimizer = optimizer_factory(model->Parameters());
 
-    auto train_iter = train_loader.begin();
+    auto train_iter = train_loader->begin();
     std::shared_ptr<nn::Module> loss_fn
         = (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
                                     std::make_shared<VocabParallelCrossEntropyLoss>(model_config.original_vocab_size))
                               : std::static_pointer_cast<nn::Module>(std::make_shared<nn::CrossEntropyLoss>());
     loss_fn->To(device);
     LOG(INFO) << "Rank " << rank.thread_rank() << ": start training";
 
+    if (FLAGS_pipeline_parallel) {
+        CHECK_EQ((FLAGS_batch_size * pp_world_size) % FLAGS_num_microbatches, 0)
+            << "FLAGS_batch_size (" << (FLAGS_batch_size * pp_world_size)
+            << ") must be divisible by FLAGS_num_microbatches (" << FLAGS_num_microbatches << ")";
+        auto shapes = std::vector<std::vector<int64_t>>{{(FLAGS_batch_size * pp_world_size) / FLAGS_num_microbatches,
+                                                         FLAGS_sequence_length, model->GetConfig()["n_embd"]}};
+
+        model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, FLAGS_num_microbatches, shapes,
+                                                                 pp_rank, optimizer_factory);
+    }
+
+    LOG(INFO) << "start training";
+
     for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
         const bool last_step = step == FLAGS_num_iteration;
 
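Two notes on this hunk: SGD construction becomes a factory callable so that the pipeline wrapper can build optimizers itself (the assumption here is that PipelineParallel invokes it per stage on that stage's local parameters), and the single entry in `shapes` is the per-micro-batch activation tensor handed from one stage to the next, {global_batch / num_microbatches, sequence_length, n_embd}. A small sketch of that shape computation (illustrative only; the example values are assumptions, with n_embd=768 matching GPT-2 small):

#include <cstdint>
#include <vector>

// Shape of the activation tensor exchanged between adjacent pipeline stages,
// matching the `shapes` initializer above.
std::vector<std::vector<int64_t>> ActivationShapes(int64_t global_batch, int64_t num_microbatches,
                                                   int64_t sequence_length, int64_t n_embd) {
    return {{global_batch / num_microbatches, sequence_length, n_embd}};
}
// Example (assumed values): global_batch=32, num_microbatches=4, sequence_length=1024, n_embd=768
// -> one {8, 1024, 768} tensor per micro-batch.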
@@ -234,7 +279,9 @@ void Train(const nn::parallel::Rank &rank) {
         }
 
         // model->Train();
-        optimizer.ZeroGrad();
+        if (!FLAGS_pipeline_parallel) {
+            optimizer->ZeroGrad();
+        }
         // if we are trying to overfit a single batch, we reset the loader here
         if (FLAGS_overfit_single_batch) {
             // train_loader.Reset();
@@ -254,6 +301,19 @@ void Train(const nn::parallel::Rank &rank) {
             ++train_iter;
             x = std::make_shared<Tensor>(x->To(device));
             y = std::make_shared<Tensor>(y->To(device));
+
+            if (FLAGS_pipeline_parallel) {
+                lossf = model->TrainStep({x}, {y}, loss_fn);
+
+                auto loss_tensor = std::make_shared<Tensor>(std::vector<int64_t>{}, DataType::kFLOAT32);
+                static_cast<float *>(loss_tensor->DataPtr())[0] = lossf;
+                auto loss_device_ptr = std::make_shared<Tensor>(loss_tensor->To(device));
+                function::AllReduce(loss_device_ptr, function::ReduceOpType::kMax);
+                auto loss_copy = loss_device_ptr->To(DeviceManager::Instance()->GetDefaultDevice());
+                lossf = static_cast<const float *>(loss_copy.DataPtr())[0];
+                continue;
+            }
+
             LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward";
             // (bs, seq_len, vocab_size)
             auto logits = model->Forward({x, y})[0];
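In the pipeline path, TrainStep runs the whole micro-batch schedule internally and then, under the assumption that only the last stage evaluates loss_fn, the scalar loss is written into a 0-d tensor and max-all-reduced so that every rank (including the one that prints the log line) sees the same value; `continue` skips the regular forward/backward code below. A minimal sketch of that scalar-broadcast trick, reusing only the calls already visible in this diff (the generic DevicePtr stands in for whatever DeviceManager::GetDevice() returns):

// Sketch only: broadcast a host-side scalar to all ranks via a max all-reduce,
// using the same Tensor / function::AllReduce API as the patch.
template <typename DevicePtr>
float AllReduceScalarMax(float value, DevicePtr device) {
    auto host = std::make_shared<Tensor>(std::vector<int64_t>{}, DataType::kFLOAT32);
    static_cast<float *>(host->DataPtr())[0] = value;
    auto dev = std::make_shared<Tensor>(host->To(device));
    // kMax keeps the largest value across ranks (presumably the stage that computed the loss).
    function::AllReduce(dev, function::ReduceOpType::kMax);
    auto back = dev->To(DeviceManager::Instance()->GetDefaultDevice());
    return static_cast<const float *>(back.DataPtr())[0];
}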
@@ -274,17 +334,19 @@ void Train(const nn::parallel::Rank &rank) {
             loss->Backward();
             LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward";
         }
-        optimizer.Step();
 
+        if (!FLAGS_pipeline_parallel) {
+            optimizer->Step();
+        }
         const auto iter_end = std::chrono::high_resolution_clock::now();
         const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
         const double tps = FLAGS_total_batch_size / (duration_us / 1e6);
 
         if (rank.IsMainRank()) {
-            LOG(ERROR) << std::format(
-                "step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, DP={}, TP={}, SP={})",
-                step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, ddp_world_size,
-                tp_world_size, sp_world_size);
+            LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
+                                      "DP={}, TP={}, SP={}, PP={})",
+                                      step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
+                                      tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size);
 
             if ((step + 1) % FLAGS_freq_generate_txt == 0) {
                 if (!tokenizer) {
@@ -304,7 +366,8 @@ int main(int argc, char *argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, true);
     google::InitGoogleLogging(argv[0]);
 
-    nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel);
+    nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
+                                     FLAGS_pipeline_parallel);
 
     // NOTE(dcj): currently we only support single process
     if (FLAGS_nthread_per_process > 1) {