
Commit 7552d8e

feat: add pipeline parallel
1 parent 62ee2a6 commit 7552d8e

File tree

22 files changed: +943 -27 lines


example/common/utils.cc

Lines changed: 7 additions & 0 deletions
@@ -61,4 +61,11 @@ std::vector<int> GetTensorParallelGroupRanks(int rank) {
     return ranks;
 }
 
+std::vector<int> GetPipelineParallelGroupRanks(int pp_world_size) {
+    std::vector<int> ranks;
+    ranks.reserve(pp_world_size);
+    for (int i = 0; i < pp_world_size; ++i) { ranks.push_back(i); }
+    return ranks;
+}
+
 } // namespace infini_train
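
The new helper simply enumerates ranks 0..pp_world_size-1, i.e. every visible device joins a single pipeline group (main.cc below passes the result to ProcessGroupFactory::Instance()->GetOrCreate). A standalone sketch, not part of the commit, that reuses the function body verbatim and prints the group for a hypothetical 4-stage pipeline:

// Standalone illustration of the helper added above; the function body is
// copied verbatim from example/common/utils.cc, the main() driver is ours.
#include <iostream>
#include <vector>

std::vector<int> GetPipelineParallelGroupRanks(int pp_world_size) {
    std::vector<int> ranks;
    ranks.reserve(pp_world_size);
    for (int i = 0; i < pp_world_size; ++i) { ranks.push_back(i); }
    return ranks;
}

int main() {
    // For a 4-stage pipeline the group is simply [0, 1, 2, 3].
    for (int r : GetPipelineParallelGroupRanks(4)) { std::cout << r << ' '; }
    std::cout << '\n';
}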

example/common/utils.h

Lines changed: 1 addition & 0 deletions
@@ -10,4 +10,5 @@ std::vector<int> GetDataParallelGroupRanks(int rank);
 
 std::vector<int> GetTensorParallelGroupRanks(int rank);
 
+std::vector<int> GetPipelineParallelGroupRanks(int rank);
 } // namespace infini_train

example/gpt2/main.cc

Lines changed: 74 additions & 11 deletions
@@ -17,6 +17,7 @@
 #include "infini_train/include/nn/parallel/distributed_data_parallel.h"
 #include "infini_train/include/nn/parallel/global.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
+#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
 #include "infini_train/include/nn/parallel/rank.h"
 #include "infini_train/include/nn/parallel/reduce_op_type.h"
 #include "infini_train/include/nn/parallel/tensor_parallel.h"
@@ -63,6 +64,11 @@ DEFINE_int32(
     "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices.");
 DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
 DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
+DEFINE_bool(
+    pipeline_parallel, false,
+    "use pipeline parallelism or not, will always use device=cuda and use all cuda visible devices when set to true");
+DEFINE_uint32(num_microbatches, 4, "the num of microbatches in pipeline parallelism");
+
 // precision
 DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
 
@@ -106,6 +112,7 @@ void Train(const nn::parallel::Rank &rank) {
     int ddp_world_size = global::GetDataParallelSize();
     int tp_world_size = global::GetTensorParallelSize();
     int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 0;
+    int pp_world_size = global::GetPipelineParallelSize();
 
     if (FLAGS_sequence_parallel) {
         CHECK_EQ(FLAGS_sequence_length % tp_world_size, 0)
@@ -114,9 +121,11 @@ void Train(const nn::parallel::Rank &rank) {
 
     int ddp_rank = 0;
     int tp_rank = 0;
+    int pp_rank = 0;
 
     const ProcessGroup *ddp_pg = nullptr;
     const ProcessGroup *tp_pg = nullptr;
+    const ProcessGroup *pp_pg = nullptr;
 
     if (rank.IsParallel()) {
         device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
@@ -134,6 +143,12 @@ void Train(const nn::parallel::Rank &rank) {
             // NOTE(zbl): Reserved for VocabParallelEmbedding
             nn::parallel::tp_rank = tp_rank;
         }
+
+        if (pp_world_size > 1) {
+            pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(
+                GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
+            pp_rank = pp_pg->GetGroupRank(rank.thread_rank());
+        }
     } else {
         device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
                                             : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0);
@@ -174,6 +189,10 @@ void Train(const nn::parallel::Rank &rank) {
         LOG(FATAL) << "Rank " << rank.thread_rank() << ": Datatype " << FLAGS_dtype << " not supported.";
     }
 
+    //TODO(jym): Temporary implementation before 3D parallelism
+    if (FLAGS_pipeline_parallel) {
+        ddp_world_size = 1;
+    }
     // NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
     // before wrapping the model with DistributedDataParallel (DDP).
     // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
@@ -182,8 +201,17 @@ void Train(const nn::parallel::Rank &rank) {
         model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
     }
 
-    DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
-                                       FLAGS_batch_size, ddp_rank, ddp_world_size);
+    std::unique_ptr<DataLoader> train_loader;
+    if (FLAGS_pipeline_parallel) {
+        train_loader = std::make_unique<DataLoader>(
+            std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
+            FLAGS_batch_size * pp_world_size);
+    } else {
+        train_loader = std::make_unique<DistributedDataLoader>(
+            std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_batch_size,
+            ddp_rank, ddp_world_size);
+    }
+
     std::optional<DistributedDataLoader> val_loader = std::nullopt;
     if (!FLAGS_input_val_bin.empty()) {
         val_loader = DistributedDataLoader(
@@ -201,16 +229,33 @@ void Train(const nn::parallel::Rank &rank) {
     }
 
     // TODO(dcj): support more complex optimizer later
-    auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
+    auto lr = FLAGS_learning_rate;
+    auto optimizer_factory = [lr](const std::vector<std::shared_ptr<Tensor>> &params) {
+        return std::make_shared<optimizers::SGD>(params, lr);
+    };
+    auto optimizer = optimizer_factory(model->Parameters());
 
-    auto train_iter = train_loader.begin();
+    auto train_iter = train_loader->begin();
     std::shared_ptr<nn::Module> loss_fn
         = (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
                                     std::make_shared<VocabParallelCrossEntropyLoss>(model_config.original_vocab_size))
                               : std::static_pointer_cast<nn::Module>(std::make_shared<nn::CrossEntropyLoss>());
     loss_fn->To(device);
     LOG(INFO) << "Rank " << rank.thread_rank() << ": start training";
 
+    if (FLAGS_pipeline_parallel) {
+        CHECK_EQ((FLAGS_batch_size * pp_world_size) % FLAGS_num_microbatches, 0)
+            << "FLAGS_batch_size (" << (FLAGS_batch_size * pp_world_size)
+            << ") must be divisible by FLAGS_num_microbatches (" << FLAGS_num_microbatches << ")";
+        auto shapes = std::vector<std::vector<int64_t>>{{(FLAGS_batch_size * pp_world_size) / FLAGS_num_microbatches,
+                                                         FLAGS_sequence_length, model->GetConfig()["n_embd"]}};
+
+        model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, FLAGS_num_microbatches, shapes,
+                                                                 pp_rank, optimizer_factory);
+    }
+
+    LOG(INFO) << "start training";
+
     for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
         const bool last_step = step == FLAGS_num_iteration;
 
@@ -234,7 +279,9 @@ void Train(const nn::parallel::Rank &rank) {
         }
 
         // model->Train();
-        optimizer.ZeroGrad();
+        if (!FLAGS_pipeline_parallel) {
+            optimizer->ZeroGrad();
+        }
         // if we are trying to overfit a single batch, we reset the loader here
         if (FLAGS_overfit_single_batch) {
             // train_loader.Reset();
@@ -254,6 +301,19 @@ void Train(const nn::parallel::Rank &rank) {
             ++train_iter;
             x = std::make_shared<Tensor>(x->To(device));
            y = std::make_shared<Tensor>(y->To(device));
+
+            if (FLAGS_pipeline_parallel) {
+                lossf = model->TrainStep({x}, {y}, loss_fn);
+
+                auto loss_tensor = std::make_shared<Tensor>(std::vector<int64_t>{}, DataType::kFLOAT32);
+                static_cast<float *>(loss_tensor->DataPtr())[0] = lossf;
+                auto loss_device_ptr = std::make_shared<Tensor>(loss_tensor->To(device));
+                function::AllReduce(loss_device_ptr, function::ReduceOpType::kMax);
+                auto loss_copy = loss_device_ptr->To(DeviceManager::Instance()->GetDefaultDevice());
+                lossf = static_cast<const float *>(loss_copy.DataPtr())[0];
+                continue;
+            }
+
             LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward";
             // (bs, seq_len, vocab_size)
             auto logits = model->Forward({x, y})[0];
@@ -274,17 +334,19 @@ void Train(const nn::parallel::Rank &rank) {
             loss->Backward();
             LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward";
         }
-        optimizer.Step();
 
+        if (!FLAGS_pipeline_parallel) {
+            optimizer->Step();
+        }
         const auto iter_end = std::chrono::high_resolution_clock::now();
         const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
        const double tps = FLAGS_total_batch_size / (duration_us / 1e6);
 
         if (rank.IsMainRank()) {
-            LOG(ERROR) << std::format(
-                "step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, DP={}, TP={}, SP={})",
-                step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, ddp_world_size,
-                tp_world_size, sp_world_size);
+            LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
+                                      "DP={}, TP={}, SP={}, PP={})",
+                                      step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
+                                      tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size);
 
             if ((step + 1) % FLAGS_freq_generate_txt == 0) {
                 if (!tokenizer) {
@@ -304,7 +366,8 @@ int main(int argc, char *argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, true);
     google::InitGoogleLogging(argv[0]);
 
-    nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel);
+    nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
+                                     FLAGS_pipeline_parallel);
 
     // NOTE(dcj): currently we only support single process
     if (FLAGS_nthread_per_process > 1) {
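
With --pipeline_parallel the loader batches FLAGS_batch_size * pp_world_size samples per step; the CHECK_EQ above requires that this global batch split evenly into FLAGS_num_microbatches, and each micro-batch then carries activations of shape {global_batch / num_microbatches, sequence_length, n_embd} between stages. A standalone sketch of that arithmetic, with illustrative values that are not the example's defaults:

// Standalone sketch of the micro-batch sizing used when --pipeline_parallel is on.
// All concrete numbers below are illustrative.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    const int64_t batch_size = 8;        // per-loader batch (illustrative)
    const int64_t pp_world_size = 4;     // number of pipeline stages
    const int64_t num_microbatches = 4;  // --num_microbatches
    const int64_t sequence_length = 1024;
    const int64_t n_embd = 768;          // e.g. GPT2::GetConfig()["n_embd"] for GPT-2 small

    const int64_t global_batch = batch_size * pp_world_size;  // loader batch in PP mode
    if (global_batch % num_microbatches != 0) { return 1; }   // mirrors the CHECK_EQ above
    const int64_t micro_batch = global_batch / num_microbatches;

    // Shape of the activations exchanged between stages for one micro-batch.
    const std::vector<int64_t> shape{micro_batch, sequence_length, n_embd};
    std::cout << shape[0] << " x " << shape[1] << " x " << shape[2] << '\n'; // 8 x 1024 x 768
}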

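The SGD construction in main.cc is also wrapped in a factory lambda so the same recipe can be handed to PipelineParallel, which presumably builds an optimizer over the parameters of its own stage (the PipelineParallel implementation itself is not part of this excerpt). A minimal standalone sketch of that pattern, using simplified stand-in types rather than the real infini_train classes:

// Minimal sketch of the optimizer-factory pattern introduced in main.cc.
// Param and SimpleSGD are stand-ins; the real types live in infini_train.
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

struct Param { float value = 0.f, grad = 0.f; };

class SimpleSGD {
public:
    SimpleSGD(std::vector<std::shared_ptr<Param>> params, float lr) : params_(std::move(params)), lr_(lr) {}
    void Step() {
        for (auto &p : params_) { p->value -= lr_ * p->grad; }
    }
private:
    std::vector<std::shared_ptr<Param>> params_;
    float lr_;
};

int main() {
    const float lr = 1e-3f;
    // The factory captures only the hyper-parameter, not the parameters themselves,
    // so each pipeline stage can instantiate an optimizer over its local slice.
    auto optimizer_factory = [lr](const std::vector<std::shared_ptr<Param>> &params) {
        return std::make_shared<SimpleSGD>(params, lr);
    };

    std::vector<std::shared_ptr<Param>> stage_params{std::make_shared<Param>()};
    stage_params[0]->grad = 2.f;
    auto optimizer = optimizer_factory(stage_params); // same call shape as in the diff
    optimizer->Step();
    std::cout << stage_params[0]->value << '\n'; // -0.002
}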
example/gpt2/net.cc

Lines changed: 54 additions & 0 deletions
@@ -212,6 +212,60 @@ GPT2::GPT2(const GPT2Config &config) : config_(config) {
         = module(kLMHeadLayerName).parameter(tp::ColumnParallelLinear::kParamWeightName);
 }
 
+class EmbeddingLayer : public nn::Module {
+
+public:
+    EmbeddingLayer(std::shared_ptr<nn::Module> wte, std::shared_ptr<nn::Module> wpe) {
+        modules_["wte"] = wte;
+        modules_["wpe"] = wpe;
+    }
+
+    std::vector<std::shared_ptr<infini_train::Tensor>>
+    Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &inputs) override {
+        auto &input_ids = inputs[0]; // (bs, seq_len)
+        const int seq_len = input_ids->Dims()[1];
+        const auto device = input_ids->GetDevice();
+
+        // position ids: [0, 1, ..., seq_len-1]
+        auto pos_ids = nn::init::Arange(0, seq_len, infini_train::DataType::kINT64, device);
+        // (bs, seq_len) -> wte -> (bs, seq_len, n_embd)
+        auto tok_emb = modules_["wte"]->Forward({input_ids})[0];
+        // (seq_len,) -> wpe -> (seq_len, n_embd)
+        auto pos_emb = modules_["wpe"]->Forward({pos_ids})[0];
+
+        auto output = tok_emb + pos_emb; // (bs, seq_len, n_embd)
+
+        return {output};
+    }
+};
+
+std::vector<std::shared_ptr<nn::Module>> GPT2::GetPipelineLayers() {
+    auto &transformer = modules_[kTransformerLayerName];
+
+    std::vector<std::shared_ptr<nn::Module>> layers;
+
+    auto embedding_layer = std::make_shared<EmbeddingLayer>(transformer->mutable_module(kWTELayerName),
+                                                            transformer->mutable_module(kWPELayerName));
+    layers.push_back(embedding_layer);
+
+    auto seq = std::dynamic_pointer_cast<nn::Sequential>(transformer->mutable_module(kHLayerName));
+    if (seq) {
+        for (int idx = 0; idx < seq->size(); ++idx) {
+            auto sub_module = (*seq)[idx];
+            layers.push_back(sub_module);
+        }
+    }
+
+    layers.push_back(transformer->mutable_module(kLnFLayerName));
+    layers.push_back(modules_[kLMHeadLayerName]);
+
+    return layers;
+}
+
+std::unordered_map<std::string, int64_t> GPT2::GetConfig() const {
+    return {{"n_embd", config_.n_embd}, {"n_head", config_.n_head}};
+}
+
 std::vector<std::shared_ptr<infini_train::Tensor>>
 GPT2::Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) {
     // (B, T)
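
GetPipelineLayers flattens GPT-2 into an ordered list: the combined token/position embedding, each transformer block, the final LayerNorm, and the LM head. How nn::parallel::PipelineParallel assigns that list to stages is not shown in this diff; the following is an illustrative standalone sketch of one plausible contiguous split across pipeline stages, not the library's actual logic:

// Illustrative only: contiguous partitioning of an ordered layer list across
// pipeline stages. The real assignment happens inside nn::parallel::PipelineParallel,
// which is not part of this excerpt.
#include <iostream>
#include <string>
#include <vector>

std::vector<std::vector<std::string>> SplitIntoStages(const std::vector<std::string> &layers, int num_stages) {
    std::vector<std::vector<std::string>> stages(num_stages);
    const int base = static_cast<int>(layers.size()) / num_stages;
    const int remainder = static_cast<int>(layers.size()) % num_stages;
    int next = 0;
    for (int s = 0; s < num_stages; ++s) {
        const int count = base + (s < remainder ? 1 : 0); // earlier stages absorb the remainder
        for (int i = 0; i < count; ++i) { stages[s].push_back(layers[next++]); }
    }
    return stages;
}

int main() {
    // Mirrors the order produced by GPT2::GetPipelineLayers() for a 4-block toy model.
    std::vector<std::string> layers{"embedding", "h0", "h1", "h2", "h3", "ln_f", "lm_head"};
    auto stages = SplitIntoStages(layers, 2);
    for (size_t s = 0; s < stages.size(); ++s) {
        std::cout << "stage " << s << ":";
        for (const auto &name : stages[s]) { std::cout << ' ' << name; }
        std::cout << '\n';
    }
    // stage 0: embedding h0 h1 h2
    // stage 1: h3 ln_f lm_head
}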

example/gpt2/net.h

Lines changed: 3 additions & 0 deletions
@@ -89,6 +89,9 @@ class GPT2 : public infini_train::nn::CloneableModule<GPT2> {
 
     explicit GPT2(const GPT2Config &config);
 
+    std::unordered_map<std::string, int64_t> GetConfig() const override;
+    std::vector<std::shared_ptr<infini_train::nn::Module>> GetPipelineLayers() override;
+
     std::vector<std::shared_ptr<infini_train::Tensor>>
     Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;
