127 changes: 85 additions & 42 deletions example/gpt2/main.cc
@@ -17,6 +17,7 @@
#include "infini_train/include/nn/parallel/distributed_data_parallel.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/nn/parallel/parallel_functional.h"
#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
#include "infini_train/include/nn/parallel/rank.h"
#include "infini_train/include/nn/parallel/reduce_op_type.h"
#include "infini_train/include/nn/parallel/tensor_parallel.h"
@@ -63,6 +64,8 @@ DEFINE_int32(
"When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices.");
DEFINE_uint32(tensor_parallel, 1, "Tensor Parallel world size");
DEFINE_bool(sequence_parallel, false, "Whether to enable Sequence Parallel");
DEFINE_uint32(pipeline_parallel, 1, "Pipeline Parallel world size, i.e. the number of PP stages.");
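// Assumption: the visible devices are expected to factor as DP x TP x PP; this
// flag fixes only the PP factor, i.e. how many pipeline stages the model is split into.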

// precision
DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");

@@ -106,6 +109,7 @@ void Train(const nn::parallel::Rank &rank) {
int ddp_world_size = global::GetDataParallelSize();
int tp_world_size = global::GetTensorParallelSize();
int sp_world_size = global::GetSequenceParallelEnabled() ? tp_world_size : 0;
int pp_world_size = global::GetPipelineParallelSize();

if (FLAGS_sequence_parallel) {
CHECK_EQ(FLAGS_sequence_length % tp_world_size, 0)
@@ -114,9 +118,11 @@ void Train(const nn::parallel::Rank &rank) {

int ddp_rank = 0;
int tp_rank = 0;
int pp_rank = 0;

const ProcessGroup *ddp_pg = nullptr;
const ProcessGroup *tp_pg = nullptr;
const ProcessGroup *pp_pg = nullptr;

if (rank.IsParallel()) {
device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
@@ -134,6 +140,14 @@
// NOTE(zbl): Reserved for VocabParallelEmbedding
nn::parallel::tp_rank = tp_rank;
}

if (pp_world_size > 1) {
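// Build the pipeline-parallel process group for this thread. pp_rank is the
// stage index inside that group (0 = first stage, pp_world_size - 1 = last);
// GetPipelineParallelGroupRanks presumably returns the thread ranks that host
// the consecutive stages of the same model replica.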
pp_pg = ProcessGroupFactory::Instance()->GetOrCreate(
GetPipelineParallelProcessGroupName(rank.thread_rank()), GetPipelineParallelGroupRanks(pp_world_size));
pp_rank = pp_pg->GetGroupRank(rank.thread_rank());

nn::parallel::pp_rank = pp_rank;
}
} else {
device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
: DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0);
@@ -182,8 +196,11 @@ void Train(const nn::parallel::Rank &rank) {
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
}

auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);
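// Illustrative arithmetic (hypothetical flag values): with total_batch_size =
// 524288 tokens, batch_size = 4, sequence_length = 1024 and ddp_world_size = 2,
// num_micro_batches = 524288 / (4 * 1024 * 2) = 64.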
DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
FLAGS_batch_size, ddp_rank, ddp_world_size);
pp_world_size > 1 ? FLAGS_batch_size * num_micro_batches : FLAGS_batch_size,
ddp_rank, ddp_world_size);
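// With PP enabled each fetch returns num_micro_batches * batch_size samples at
// once; the PipelineParallel wrapper below is presumably responsible for
// slicing that batch back into micro-batches.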

std::optional<DistributedDataLoader> val_loader = std::nullopt;
if (!FLAGS_input_val_bin.empty()) {
val_loader = DistributedDataLoader(
@@ -211,6 +228,15 @@ void Train(const nn::parallel::Rank &rank) {
loss_fn->To(device);
LOG(INFO) << "Rank " << rank.thread_rank() << ": start training";

if (pp_world_size > 1) {
auto shapes = std::vector<std::vector<int64_t>>{{FLAGS_batch_size, FLAGS_sequence_length, model_config.n_embd}};
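// shapes describes the activation tensor exchanged between adjacent stages,
// assumed to be the hidden states of shape (batch_size, sequence_length, n_embd).
// The SGD copy handed to the wrapper presumably performs the parameter update in
// the PP path, since the training loop below never calls optimizer.Step() when
// pp_world_size > 1.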

model = std::make_shared<nn::parallel::PipelineParallel>(model, pp_world_size, num_micro_batches, shapes,
pp_rank, std::make_shared<optimizers::SGD>(optimizer));
}

LOG(INFO) << "start training";

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
const bool last_step = step == FLAGS_num_iteration;

@@ -233,64 +259,80 @@
break;
}

// model->Train();
optimizer.ZeroGrad();
// if we are trying to overfit a single batch, we reset the loader here
if (FLAGS_overfit_single_batch) {
// train_loader.Reset();
}
float lossf = 0.0f;
#ifdef PROFILE_MODE
Profiler::Instance().SetTag("Step_" + std::to_string(step));
#endif
for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) {
// enable autocast for the current step
infini_train::AutocastGuard autocast_guard(device->Type(), dtype);

// (bs, seq_len), (bs, seq_len)
float lossf = 0.0f;
// model->Train();
if (pp_world_size == 1) {
optimizer.ZeroGrad();

// if we are trying to overfit a single batch, we reset the loader here
if (FLAGS_overfit_single_batch) {
// train_loader.Reset();
}

for (int micro_step = 0; micro_step < grad_accum_steps; ++micro_step) {
// enable autocast for the current step
infini_train::AutocastGuard autocast_guard(device->Type(), dtype);

// (bs, seq_len), (bs, seq_len)
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
// TODO(dcj): support dataloader.reset() later
++train_iter;
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));

LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward";
// (bs, seq_len, vocab_size)
auto logits = model->Forward({x, y})[0];
LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward";
auto loss = loss_fn->Forward({logits, y})[0];
loss = loss / grad_accum_steps;

// disable autocast for the current step (backward is not under autocast)
autocast_guard.Disable();

LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
if (ddp_world_size > 1) {
function::AllReduce(loss, function::ReduceOpType::kAvg);
}
auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward";
loss->Backward();
LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward";
}

optimizer.Step();
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
// TODO(dcj): support dataloader.reset() later
++train_iter;
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));
LOG(INFO) << "Rank " << rank.thread_rank() << ": start forward";
// (bs, seq_len, vocab_size)
auto logits = model->Forward({x, y})[0];
LOG(INFO) << "Rank " << rank.thread_rank() << ": finish model forward, start loss forward";
auto loss = loss_fn->Forward({logits, y})[0];
loss = loss / grad_accum_steps;

// disable autocast for the current step (backward is not under autocast)
autocast_guard.Disable();

LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
if (ddp_world_size > 1) {
function::AllReduce(loss, function::ReduceOpType::kAvg);
}
auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
lossf += static_cast<const float *>(loss_cpu.DataPtr())[0];
LOG(INFO) << "Rank " << rank.thread_rank() << ": start backward";
loss->Backward();
LOG(INFO) << "Rank " << rank.thread_rank() << ": finish backward";
}
optimizer.Step();

lossf = model->TrainStep({x}, {y}, loss_fn);
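// TrainStep is assumed to run the whole pipelined step internally: split x/y
// into num_micro_batches micro-batches, schedule forward/backward per stage
// with activation send/recv between neighbouring stages, and apply the wrapped
// SGD step; the return value is the accumulated loss, meaningful on the last stage.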
}
const auto iter_end = std::chrono::high_resolution_clock::now();
const double duration_us = std::chrono::duration<double, std::micro>(iter_end - iter_start).count();
const double tps = FLAGS_total_batch_size / (duration_us / 1e6);

if (rank.IsMainRank()) {
LOG(ERROR) << std::format(
"step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, DP={}, TP={}, SP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, tps, ddp_world_size,
tp_world_size, sp_world_size);
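// Under PP only the last stage ends up with the loss, so the training log is
// printed from thread pp_world_size - 1 instead of the main rank (this assumes
// stages are mapped to thread ranks 0..pp_world_size-1 in order).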
if (rank.thread_rank() == pp_world_size - 1) {
LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
"DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
if (!tokenizer) {
continue;
if (tokenizer) {
// FIXME(jym): to support PP
CHECK_EQ(pp_world_size, 1);
tokenizer->GenerateText(*model, FLAGS_batch_size, FLAGS_sequence_length, FLAGS_text_length, device);
}
tokenizer->GenerateText(*model, FLAGS_batch_size, FLAGS_sequence_length, FLAGS_text_length, device);
}
}
}
@@ -304,7 +346,8 @@ int main(int argc, char *argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);

nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel);
nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel, FLAGS_sequence_parallel,
FLAGS_pipeline_parallel);
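// InitAllEnv is assumed to register the TP/SP/PP sizes queried through
// nn::parallel::global in Train() and to derive the DP size from the remaining
// factor, i.e. DP = nthread_per_process / (tensor_parallel * pipeline_parallel).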

LOG(INFO) << nn::parallel::global::ProcessGroupOverview();
