
Commit ee6bd7a

feat: Pipeline parallelism divides the model into chunks during construction
1 parent 962509c commit ee6bd7a
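
With this change, the pipeline-parallel wrapper partitions the wrapped model into per-stage chunks at construction time instead of deciding the split during training. The sketch below only illustrates the general idea of a contiguous, near-even layer split across pp_world_size stages; the helper name and the splitting policy are assumptions made for illustration, not InfiniTrain's actual PipelineParallel internals.

// Illustrative sketch (hypothetical helper, not from this repository):
// assign a contiguous [begin, end) range of layers to each pipeline stage.
#include <algorithm>
#include <cstdint>
#include <utility>

std::pair<int64_t, int64_t> StageLayerRange(int64_t num_layers, int stage, int num_stages) {
    const int64_t base = num_layers / num_stages;
    const int64_t extra = num_layers % num_stages;  // the first `extra` stages take one extra layer
    const int64_t begin = stage * base + std::min<int64_t>(stage, extra);
    const int64_t end = begin + base + (stage < extra ? 1 : 0);
    return {begin, end};  // layers [begin, end) live on this stage
}

Each rank then owns only its chunk of layers, so only per-microbatch activations need to cross stage boundaries during training.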

File tree

22 files changed: +795 -841 lines changed


example/common/utils.cc

Lines changed: 0 additions & 7 deletions

@@ -61,11 +61,4 @@ void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t s
     ifs.seekg(base + std::streamoff(len * sizeof(float)));
 }
 
-std::vector<int> GetPipelineParallelGroupRanks(int pp_world_size) {
-    std::vector<int> ranks;
-    ranks.reserve(pp_world_size);
-    for (int i = 0; i < pp_world_size; ++i) { ranks.push_back(i); }
-    return ranks;
-}
-
 } // namespace infini_train

example/common/utils.h

Lines changed: 0 additions & 1 deletion

@@ -30,5 +30,4 @@ void ReadVectorAllFloat(std::ifstream &ifs, float *dst, int64_t len);
 
 void ReadVectorShardFloat(std::ifstream &ifs, float *dst, int64_t len, int64_t start, int64_t cnt);
 
-std::vector<int> GetPipelineParallelGroupRanks(int rank);
 } // namespace infini_train
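
Note that example/gpt2/main.cc (below) still calls GetPipelineParallelGroupRanks when it builds the pipeline process group, so after this commit the helper presumably comes from the library itself rather than from the example's utils (not verified in this excerpt). For reference, a minimal equivalent of the removed example-local helper, which simply enumerates the contiguous ranks 0 .. pp_world_size - 1:

// Minimal equivalent of the helper removed above; the pipeline-parallel group
// is just the contiguous ranks [0, pp_world_size).
#include <vector>

std::vector<int> GetPipelineParallelGroupRanks(int pp_world_size) {
    std::vector<int> ranks;
    ranks.reserve(pp_world_size);
    for (int i = 0; i < pp_world_size; ++i) { ranks.push_back(i); }
    return ranks;
}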

example/gpt2/main.cc

Lines changed: 65 additions & 81 deletions

@@ -64,10 +64,7 @@ DEFINE_int32(
     "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices.");
 DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
 DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
-DEFINE_uint32(
-    pipeline_parallel, 1,
-    "Pipeline Parallel world size, will always use device=cuda and use all cuda visible devices when set to true");
-DEFINE_uint32(num_microbatches, 4, "the num of microbatches in pipeline parallelism");
+DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, specifies the number of PP stages.");
 
 // precision
 DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
@@ -148,14 +145,16 @@ void Train(const nn::parallel::Rank &rank) {
             pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(
                 GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
             pp_rank = pp_pg->GetGroupRank(rank.thread_rank());
+
+            nn::parallel::pp_rank = pp_rank;
         }
     } else {
         device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
                                             : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0);
     }
 
     // calculate gradient accumulation from the desired total batch size and the current run configuration
-    const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * (ddp_world_size * pp_world_size);
+    const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size;
     CHECK_EQ(FLAGS_total_batch_size % tokens_per_fwdbwd, 0);
     const auto grad_accum_steps = FLAGS_total_batch_size / tokens_per_fwdbwd;
     LOG(INFO) << "total desired batch size: " << FLAGS_total_batch_size
@@ -197,16 +196,10 @@ void Train(const nn::parallel::Rank &rank) {
         model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
     }
 
-    std::unique_ptr<DataLoader> train_loader;
-    if (pp_world_size > 1) {
-        train_loader = std::make_unique<DataLoader>(
-            std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
-            FLAGS_batch_size * pp_world_size);
-    } else {
-        train_loader = std::make_unique<DistributedDataLoader>(
-            std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length), FLAGS_batch_size,
-            ddp_rank, ddp_world_size);
-    }
+    auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
+    DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
+                                       pp_world_size > 1 ? FLAGS_batch_size * num_micro_batches : FLAGS_batch_size,
+                                       ddp_rank, ddp_world_size);
 
     std::optional<DistributedDataLoader> val_loader = std::nullopt;
     if (!FLAGS_input_val_bin.empty()) {
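
With the formulas above, grad_accum_steps and num_micro_batches come from the same expression, FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size), so the number of microbatches scheduled per pipeline step matches the number of gradient-accumulation steps on the non-pipeline path. A quick check with hypothetical flag values (not the repository's defaults):

// Hypothetical values, chosen only to make the arithmetic concrete.
#include <cassert>
#include <cstdint>

int main() {
    const int64_t total_batch_size = 524288; // tokens per optimizer step (hypothetical)
    const int64_t batch_size = 8;            // sequences per microbatch (hypothetical)
    const int64_t sequence_length = 1024;    // tokens per sequence (hypothetical)
    const int64_t ddp_world_size = 4;

    const int64_t tokens_per_fwdbwd = batch_size * sequence_length * ddp_world_size;                      // 32768
    assert(total_batch_size % tokens_per_fwdbwd == 0);
    const int64_t grad_accum_steps = total_batch_size / tokens_per_fwdbwd;                                // 16
    const int64_t num_micro_batches = total_batch_size / (batch_size * sequence_length * ddp_world_size); // 16
    assert(grad_accum_steps == num_micro_batches);
    return 0;
}

When pp_world_size > 1, the loader therefore yields FLAGS_batch_size * num_micro_batches sequences per iteration, which the pipeline wrapper presumably slices back into microbatches of FLAGS_batch_size sequences each.
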
@@ -225,13 +218,9 @@ void Train(const nn::parallel::Rank &rank) {
     }
 
     // TODO(dcj): support more complex optimizer later
-    auto lr = FLAGS_learning_rate;
-    auto optimizer_factory = [lr](const std::vector<std::shared_ptr<Tensor>> &params) {
-        return std::make_shared<optimizers::SGD>(params, lr);
-    };
-    auto optimizer = optimizer_factory(model->Parameters());
+    auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
 
-    auto train_iter = train_loader->begin();
+    auto train_iter = train_loader.begin();
     std::shared_ptr<nn::Module> loss_fn
         = (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
               std::make_shared<VocabParallelCrossEntropyLoss>(model_config.original_vocab_size))
@@ -240,14 +229,10 @@ void Train(const nn::parallel::Rank &rank) {
     LOG(INFO) << "Rank " << rank.thread_rank() << ": start training";
 
     if (pp_world_size > 1) {
-        CHECK_EQ((FLAGS_batch_size * pp_world_size) % FLAGS_num_microbatches, 0)
-            << "FLAGS_batch_size (" << (FLAGS_batch_size * pp_world_size)
-            << ") must be divisible by FLAGS_num_microbatches (" << FLAGS_num_microbatches << ")";
-        auto shapes = std::vector<std::vector<int64_t>>{{(FLAGS_batch_size * pp_world_size) / FLAGS_num_microbatches,
-                                                         FLAGS_sequence_length, model->GetConfig()["n_embd"]}};
-
-        model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, FLAGS_num_microbatches, shapes,
-                                                                 pp_rank, optimizer_factory);
+        auto shapes = std::vector<std::vector<int64_t>>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}};
+
+        model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
+                                                                 pp_rank, std::make_shared<optimizers::SGD>(optimizer));
     }
 
     LOG(INFO) << "start training";
@@ -274,81 +259,80 @@ void Train(const nn::parallel::Rank &rank) {
             break;
         }
 
-        // model->Train();
-        if (pp_world_size == 1) {
-            optimizer->ZeroGrad();
-        }
-        // if we are trying to overfit a single batch, we reset the loader here
-        if (FLAGS_overfit_single_batch) {
-            // train_loader.Reset();
-        }
-        float lossf = 0.0f;
 #ifdef PROFILE_MODE
         Profiler::Instance().SetTag("Step_" + std::to_string(step));
 #endif
-        for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) {
-            // enable autocast for the current step
-            infini_train::AutocastGuard autocast_guard(device->Type(), dtype);
 
-            // (bs, seq_len), (bs, seq_len)
+        float lossf = 0.0f;
+        // model->Train();
+        if (pp_world_size == 1) {
+            optimizer.ZeroGrad();
+
+            // if we are trying to overfit a single batch, we reset the loader here
+            if (FLAGS_overfit_single_batch) {
+                // train_loader.Reset();
+            }
+
+            for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) {
+                // enable autocast for the current step
+                infini_train::AutocastGuard autocast_guard(device->Type(), dtype);
+
+                // (bs, seq_len), (bs, seq_len)
+                auto [x, y] = *train_iter;
+                // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
+                // TODO(dcj): support dataloader.reset() later
+                ++train_iter;
+                x = std::make_shared<Tensor>(x->To(device));
+                y = std::make_shared<Tensor>(y->To(device));
+
+                LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward";
+                // (bs, seq_len, vocab_size)
+                auto logits = model->Forward({x, y})[0];
+                LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward";
+                auto loss = loss_fn->Forward({logits, y})[0];
+                loss = loss / grad_accum_steps;
+
+                // disable autocast for the current step (backward is not under autocast)
+                autocast_guard.Disable();
+
+                LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
+                if (ddp_world_size > 1) {
+                    function::AllReduce(loss, function::ReduceOpType::kAvg);
+                }
+                auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
+                lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
+                LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward";
+                loss->Backward();
+                LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward";
+            }
+
+            optimizer.Step();
+        } else {
             auto [x, y] = *train_iter;
             // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
             // TODO(dcj): support dataloader.reset() later
             ++train_iter;
             x = std::make_shared<Tensor>(x->To(device));
             y = std::make_shared<Tensor>(y->To(device));
 
-            if (pp_world_size > 1) {
-                lossf = model->TrainStep({x}, {y}, loss_fn);
-
-                auto loss_tensor = std::make_shared<Tensor>(std::vector<int64_t>{}, DataType::kFLOAT32);
-                static_cast<float *>(loss_tensor->DataPtr())[0] = lossf;
-                auto loss_device_ptr = std::make_shared<Tensor>(loss_tensor->To(device));
-                function::AllReduce(loss_device_ptr, function::ReduceOpType::kMax);
-                auto loss_copy = loss_device_ptr->To(DeviceManager::Instance()->GetDefaultDevice());
-                lossf = static_cast<const float *>(loss_copy.DataPtr())[0];
-                continue;
-            }
-
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward";
-            // (bs, seq_len, vocab_size)
-            auto logits = model->Forward({x, y})[0];
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward";
-            auto loss = loss_fn->Forward({logits, y})[0];
-            loss = loss / grad_accum_steps;
-
-            // disable autocast for the current step (backward is not under autocast)
-            autocast_guard.Disable();
-
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
-            if (ddp_world_size > 1) {
-                function::AllReduce(loss, function::ReduceOpType::kAvg);
-            }
-            auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
-            lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward";
-            loss->Backward();
-            LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward";
-        }
-
-        if (pp_world_size == 1) {
-            optimizer->Step();
+            lossf = model->TrainStep({x}, {y}, loss_fn);
         }
         const auto iter_end = std::chrono::high_resolution_clock::now();
         const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
         const double tps = FLAGS_total_batch_size / (duration_us / 1e6);
 
-        if (rank.IsMainRank()) {
+        if (rank.thread_rank() == pp_world_size - 1) {
            LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
                                      "DP={}, TP={}, SP={}, PP={})",
                                      step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
                                      tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size);
 
             if ((step + 1) % FLAGS_freq_generate_txt == 0) {
-                if (!tokenizer) {
-                    continue;
+                if (tokenizer) {
+                    // FIXME(jym): to support PP
+                    CHECK_EQ(pp_world_size, 1);
+                    tokenizer->GenerateText(*model, FLAGS_batch_size, FLAGS_sequence_length, FLAGS_text_length, device);
                 }
-                tokenizer->GenerateText(*model, FLAGS_batch_size, FLAGS_sequence_length, FLAGS_text_length, device);
             }
         }
     }
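
On pipeline ranks, model->TrainStep({x}, {y}, loss_fn) now replaces the explicit gradient-accumulation loop, and the per-step report is printed from rank pp_world_size - 1 rather than the main rank, presumably because only the last stage holds the loss. The sketch below shows the general shape of such a step (split the batch into microbatches, run forward and backward per microbatch, then step the stage-local optimizer); it is a generic GPipe-style illustration with stand-in types, not InfiniTrain's TrainStep implementation.

// Schematic only: one pipeline training step on a single stage.
// Plain doubles stand in for tensors; communication and autograd are stubbed out.
#include <cstddef>
#include <vector>

struct StageStub {
    double Forward(double activation_in) { return activation_in * 2.0; } // placeholder stage compute
    void Backward(double /*grad_in*/) {}                                 // placeholder backward; grads accumulate
    void OptimizerStep() {}                                              // stage-local optimizer update
};

double TrainStepSketch(StageStub &stage, const std::vector<double> &microbatches, bool is_last_stage) {
    double loss_sum = 0.0;
    for (double mb : microbatches) {     // 1) forward each microbatch through this stage
        double out = stage.Forward(mb);  //    (a real stage would send `out` to the next stage)
        if (is_last_stage) {
            loss_sum += out;             // 2) the last stage computes the loss per microbatch
        }
    }
    for (std::size_t i = 0; i < microbatches.size(); ++i) {
        stage.Backward(1.0);             // 3) backward per microbatch; gradients accumulate locally
    }
    stage.OptimizerStep();               // 4) a single optimizer step per training step
    return microbatches.empty() ? 0.0 : loss_sum / static_cast<double>(microbatches.size());
}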
