
Commit 9c47bc2 (parent: 962509c)

feat: Pipeline parallelism divides the model into chunks during construction

File tree

16 files changed: +634, -501 lines
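The commit message describes building pipeline-parallel stages by splitting the model into contiguous chunks at construction time, rather than asking the model for its layers afterwards (the GetPipelineLayers() hooks removed below). As orientation, a standalone sketch of contiguous chunking follows; ChunkRange, the layer count, and the stage count are illustrative assumptions, not code from this commit.

// Hypothetical illustration, not code from this commit: divide num_layers
// transformer blocks into num_stages contiguous chunks at construction time.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>

// Returns the [begin, end) range of global layer indices owned by `stage`.
std::pair<int64_t, int64_t> ChunkRange(int64_t num_layers, int64_t num_stages, int64_t stage) {
    const int64_t base = num_layers / num_stages;   // every stage gets at least this many layers
    const int64_t extra = num_layers % num_stages;  // the first `extra` stages get one more
    const int64_t begin = stage * base + std::min(stage, extra);
    const int64_t end = begin + base + (stage < extra ? 1 : 0);
    return {begin, end};
}

int main() {
    const int64_t num_layers = 12, num_stages = 4;
    for (int64_t stage = 0; stage < num_stages; ++stage) {
        const auto [begin, end] = ChunkRange(num_layers, num_stages, stage);
        std::cout << "stage " << stage << " owns layers [" << begin << ", " << end << ")\n";
    }
    return 0;
}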

example/gpt2/main.cc

Lines changed: 6 additions & 3 deletions
@@ -148,6 +148,8 @@ void Train(const nn::parallel::Rank &rank) {
             pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(
                 GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
             pp_rank = pp_pg->GetGroupRank(rank.thread_rank());
+
+            nn::parallel::pp_rank = pp_rank;
         }
     } else {
         device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
@@ -243,8 +245,9 @@ void Train(const nn::parallel::Rank &rank) {
         CHECK_EQ((FLAGS_batch_size * pp_world_size) % FLAGS_num_microbatches, 0)
             << "FLAGS_batch_size (" << (FLAGS_batch_size * pp_world_size)
             << ") must be divisible by FLAGS_num_microbatches (" << FLAGS_num_microbatches << ")";
-        auto shapes = std::vector<std::vector<int64_t>>{{(FLAGS_batch_size * pp_world_size) / FLAGS_num_microbatches,
-                                                         FLAGS_sequence_length, model->GetConfig()["n_embd"]}};
+
+        auto shapes = std::vector<std::vector<int64_t>>{
+            {(FLAGS_batch_size * pp_world_size) / FLAGS_num_microbatches, FLAGS_sequence_length, model_config.n_embd}};

         model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, FLAGS_num_microbatches, shapes,
                                                                   pp_rank, optimizer_factory);
@@ -298,9 +301,9 @@ void Train(const nn::parallel::Rank &rank) {
         x = std::make_shared<Tensor>(x->To(device));
         y = std::make_shared<Tensor>(y->To(device));

+        // FIXME(jym): without gradient accumulation
         if (pp_world_size > 1) {
             lossf = model->TrainStep({x}, {y}, loss_fn);
-
             auto loss_tensor = std::make_shared<Tensor>(std::vector<int64_t>{}, DataType::kFLOAT32);
             static_cast<float *>(loss_tensor->DataPtr())[0] = lossf;
             auto loss_device_ptr = std::make_shared<Tensor>(loss_tensor->To(device));
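The receive shape passed to PipelineParallel is now built from the model_config struct rather than the removed GetConfig() map, and it describes one microbatch of hidden states. A small worked example of that arithmetic follows; the flag values are assumptions, while 1024 and 768 are the usual GPT-2 small sequence length and embedding width.

// Hypothetical flag values to illustrate the per-microbatch shape computed above.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    const int64_t batch_size = 4, pp_world_size = 2, num_microbatches = 4;  // assumed flag values
    const int64_t sequence_length = 1024, n_embd = 768;                     // GPT-2 small defaults
    // The global batch is batch_size * pp_world_size; each microbatch carries an equal slice of it.
    const std::vector<int64_t> shape = {(batch_size * pp_world_size) / num_microbatches, sequence_length, n_embd};
    std::cout << shape[0] << " x " << shape[1] << " x " << shape[2] << "\n";  // prints: 2 x 1024 x 768
    return 0;
}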

example/gpt2/net.cc

Lines changed: 342 additions & 223 deletions
Large diffs are not rendered by default.

example/gpt2/net.h

Lines changed: 0 additions & 3 deletions
@@ -89,9 +89,6 @@ class GPT2 : public infini_train::nn::CloneableModule<GPT2> {

     explicit GPT2(const GPT2Config &config);

-    std::unordered_map<std::string, int64_t> GetConfig() const override;
-    std::vector<std::shared_ptr<infini_train::nn::Module>> GetPipelineLayers() override;
-
     std::vector<std::shared_ptr<infini_train::Tensor>>
     Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

example/llama3/main.cc

Lines changed: 4 additions & 2 deletions
@@ -130,6 +130,8 @@ void Train(const nn::parallel::Rank &rank) {
             pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(
                 GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
             pp_rank = pp_pg->GetGroupRank(rank.thread_rank());
+
+            nn::parallel::pp_rank = pp_rank;
         }
     } else {
         device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
@@ -222,8 +224,8 @@ void Train(const nn::parallel::Rank &rank) {
             << "FLAGS_batch_size (" << (FLAGS_batch_size * pp_world_size)
             << ") must be divisible by FLAGS_num_microbatches (" << FLAGS_num_microbatches << ")";

-        auto shapes = std::vector<std::vector<int64_t>>{{FLAGS_batch_size * pp_world_size / FLAGS_num_microbatches,
-                                                         FLAGS_sequence_length, model->GetConfig()["n_embd"]}};
+        auto shapes = std::vector<std::vector<int64_t>>{
+            {FLAGS_batch_size * pp_world_size / FLAGS_num_microbatches, FLAGS_sequence_length, model_config.n_embd}};

         model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, FLAGS_num_microbatches, shapes,
                                                                   pp_rank, optimizer_factory);

example/llama3/net.cc

Lines changed: 202 additions & 174 deletions
Large diffs are not rendered by default.

example/llama3/net.h

Lines changed: 0 additions & 4 deletions
@@ -129,10 +129,6 @@ class LLaMA3 : public infini_train::nn::CloneableModule<LLaMA3> {

     explicit LLaMA3(const LLaMA3Config &config);

-    std::vector<std::shared_ptr<infini_train::nn::Module>> GetPipelineLayers() override;
-
-    std::unordered_map<std::string, int64_t> GetConfig() const;
-
     std::vector<std::shared_ptr<infini_train::Tensor>>
     Forward(const std::vector<std::shared_ptr<infini_train::Tensor>> &x) override;

infini_train/include/nn/modules/container.h

Lines changed: 21 additions & 0 deletions
@@ -5,6 +5,7 @@
 #include <vector>

 #include "infini_train/include/nn/modules/module.h"
+#include <iostream>

 namespace infini_train {
 class Tensor;
@@ -53,4 +54,24 @@ class ModuleList : public CloneableModule<ModuleList> {
 private:
     std::vector<std::shared_ptr<Module>> module_list_;
 };
+
+class PipelineModuleList : public CloneableModule<PipelineModuleList> {
+public:
+    static constexpr char kType[] = "PipelineModuleList";
+
+    explicit PipelineModuleList(std::vector<std::shared_ptr<nn::Module>> &&modules, int64_t global_start_index);
+
+    std::unordered_map<std::string, std::shared_ptr<Tensor>> StateDict() const override;
+
+    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;
+
+    auto begin() { return local_modules_.begin(); }
+    auto end() { return local_modules_.end(); }
+    auto begin() const { return local_modules_.begin(); }
+    auto end() const { return local_modules_.end(); }
+
+private:
+    std::vector<std::shared_ptr<nn::Module>> local_modules_;
+    int64_t global_start_;
+};
 } // namespace infini_train::nn
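The matching container.cc is not rendered in this view, so the semantics of PipelineModuleList have to be read off this header. Below is a minimal standalone sketch of one plausible reading, using simplified stand-in types rather than the repository's classes: Forward chains the stage-local modules in order, and StateDict keys each module's tensors by its global layer index (starting at global_start_) so that all stages share one checkpoint namespace.

// Sketch under assumptions; Tensor and Module here are simplified stand-ins.
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Tensor {};

struct Module {
    virtual ~Module() = default;
    virtual std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &x) = 0;
    virtual std::unordered_map<std::string, std::shared_ptr<Tensor>> StateDict() const = 0;
};

class PipelineModuleList {
public:
    PipelineModuleList(std::vector<std::shared_ptr<Module>> &&modules, int64_t global_start_index)
        : local_modules_(std::move(modules)), global_start_(global_start_index) {}

    // Run the stage-local modules in sequence: each module's output feeds the next.
    std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
        auto x = input_tensors;
        for (const auto &m : local_modules_) { x = m->Forward(x); }
        return x;
    }

    // Prefix every entry with the module's *global* index, e.g. "4.weight" for the
    // first local module of a stage whose global_start_ is 4.
    std::unordered_map<std::string, std::shared_ptr<Tensor>> StateDict() const {
        std::unordered_map<std::string, std::shared_ptr<Tensor>> out;
        for (size_t i = 0; i < local_modules_.size(); ++i) {
            for (const auto &[name, tensor] : local_modules_[i]->StateDict()) {
                out.emplace(std::to_string(global_start_ + static_cast<int64_t>(i)) + "." + name, tensor);
            }
        }
        return out;
    }

private:
    std::vector<std::shared_ptr<Module>> local_modules_;
    int64_t global_start_;
};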

infini_train/include/nn/modules/module.h

Lines changed: 1 addition & 3 deletions
@@ -43,12 +43,10 @@ class Module : public std::enable_shared_from_this<Module> {
     std::shared_ptr<Module> mutable_module(const std::string &name);
     const Module &module(const std::string &name) const;

-    std::unordered_map<std::string, std::shared_ptr<Tensor>> StateDict() const;
+    virtual std::unordered_map<std::string, std::shared_ptr<Tensor>> StateDict() const;

     virtual std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors);

-    virtual std::vector<std::shared_ptr<Module>> GetPipelineLayers() { return {}; }
-
     virtual std::unordered_map<std::string, int64_t> GetConfig() const { return {}; }

     virtual float TrainStep(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
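Making Module::StateDict virtual is what allows a container such as PipelineModuleList to substitute its own parameter map when it is held through a Module pointer. A generic illustration of that dispatch, with simplified types that are not the repository's classes:

// Illustration only: virtual dispatch picks the derived StateDict through a base pointer.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Base {
    virtual ~Base() = default;
    virtual std::unordered_map<std::string, int> StateDict() const { return {{"base.param", 0}}; }
};

struct Derived : Base {
    // Overrides the base mapping; only reachable through Base* because StateDict is virtual.
    std::unordered_map<std::string, int> StateDict() const override { return {{"0.weight", 1}}; }
};

int main() {
    const std::shared_ptr<Base> m = std::make_shared<Derived>();
    for (const auto &[name, value] : m->StateDict()) {
        std::cout << name << " = " << value << "\n";  // prints "0.weight = 1"
    }
    return 0;
}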

infini_train/include/nn/parallel/pp/pipeline_parallel.h

Lines changed: 5 additions & 11 deletions
@@ -13,6 +13,8 @@

 namespace infini_train::nn::parallel {

+extern thread_local int pp_rank;
+
 using OptimizerFactory = std::function<std::shared_ptr<Optimizer>(const std::vector<std::shared_ptr<Tensor>> &params)>;

 class PipelineParallel : public Module {
@@ -26,22 +28,14 @@ class PipelineParallel : public Module {
 private:
     int num_stages_;
     int rank_;
-    std::vector<const Device *> devices_;
     std::shared_ptr<Module> original_model_;
     std::shared_ptr<PipelineStage> pipeline_stage_;
     std::shared_ptr<PipelineSchedule> schedule_;

-    std::vector<std::vector<std::shared_ptr<Module>>>
-    SplitLayersIntoStages(std::vector<std::shared_ptr<Module>> layers);
-
-    void SplitModel(const std::vector<std::vector<int64_t>> &recv_shape, OptimizerFactory optimizer_factory);
-
-    std::vector<std::shared_ptr<Optimizer>>
-    CreateOptimizers(const std::vector<std::vector<std::shared_ptr<Module>>> &stage_layers,
-                     OptimizerFactory optimizer_factory);
+    std::shared_ptr<Optimizer> CreateOptimizer(const std::shared_ptr<Module> &model,
+                                               OptimizerFactory optimizer_factory);

-    void BuildPipelineStage(const std::vector<std::vector<std::shared_ptr<Module>>> &stage_layers,
-                            const std::vector<std::shared_ptr<Optimizer>> &optimizers,
+    void BuildPipelineStage(const std::shared_ptr<Module> &model, const std::shared_ptr<Optimizer> &optimizers,
                             const std::vector<std::vector<int64_t>> &recv_shape);

     void SetupSchedule(int num_microbatches);
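The header now exposes a thread-local pp_rank, which the training loops in main.cc fill in from the pipeline process-group rank. A minimal sketch of that pattern follows; the definition site and the consumer are assumptions, only the extern declaration above comes from this commit.

// Sketch of the thread-local rank pattern (assumed layout, not the repository's .cc).
#include <iostream>
#include <thread>
#include <vector>

namespace infini_train::nn::parallel {
thread_local int pp_rank = 0;  // one definition in some .cc; every thread sees its own copy
}  // namespace infini_train::nn::parallel

// Code running on a training thread can read its own stage rank without the value
// being threaded through every call site.
void PrintStage() { std::cout << "pipeline stage " << infini_train::nn::parallel::pp_rank << "\n"; }

int main() {
    std::vector<std::thread> workers;
    for (int rank = 0; rank < 2; ++rank) {
        workers.emplace_back([rank] {
            infini_train::nn::parallel::pp_rank = rank;  // mirrors `nn::parallel::pp_rank = pp_rank;` in main.cc
            PrintStage();
        });
    }
    for (auto &t : workers) { t.join(); }
    return 0;
}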

infini_train/include/nn/parallel/pp/pipeline_stage.h

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@ namespace infini_train::nn::parallel {

 class PipelineStage {
 public:
-    PipelineStage(const std::vector<std::shared_ptr<Module>> &layers, int stage_index, int num_stages,
+    PipelineStage(const std::shared_ptr<Module> &model, int stage_index, int num_stages,
                   const std::vector<std::vector<int64_t>> &recvShape, std::shared_ptr<Optimizer> optim);

     std::vector<std::shared_ptr<Tensor>> ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs);
@@ -34,7 +34,7 @@ class PipelineStage {
     int next_rank_;
     const Device *device_ = nullptr;
     std::vector<std::vector<int64_t>> recv_shape_;
-    std::vector<std::shared_ptr<Module>> layers_;
+    std::shared_ptr<Module> model_;
     std::shared_ptr<Optimizer> optim_;
 };
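With the stage owning a single Module (the chunk built at construction) instead of a layer vector, ForwardOneChunk can presumably collapse to one delegation. A simplified, hypothetical sketch with stand-in types, not the repository's implementation:

// Sketch only: a stage that forwards one microbatch through its chunk of the model.
#include <memory>
#include <utility>
#include <vector>

struct Tensor {};

struct Module {
    virtual ~Module() = default;
    virtual std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &x) = 0;
};

class PipelineStage {
public:
    explicit PipelineStage(std::shared_ptr<Module> model) : model_(std::move(model)) {}

    // One microbatch flows through the whole chunk in a single call, instead of the
    // stage iterating over an explicit list of layers.
    std::vector<std::shared_ptr<Tensor>> ForwardOneChunk(const std::vector<std::shared_ptr<Tensor>> &inputs) {
        return model_->Forward(inputs);
    }

private:
    std::shared_ptr<Module> model_;
};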
