73 changes: 59 additions & 14 deletions example/gpt2/main.cc
@@ -15,6 +15,7 @@
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/distributed_data_parallel.h"
#include "infini_train/include/nn/parallel/distributed_optimizer.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/nn/parallel/parallel_functional.h"
#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
@@ -49,6 +50,7 @@ DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer (only takes effect when DP > 1)");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
@@ -190,31 +192,31 @@ void Train(const nn::parallel::Rank &rank) {

auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);

// TODO(dcj): support more complex optimizer later
auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);

if (pp_world_size > 1) {
// NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
// when sequence parallelism (SP) is enabled, we need to divide by sp_world_size.
auto shapes = std::vector<std::vector<int64_t>>{
{FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};

model = std::make_shared<nn::parallel::PipelineParallel>(
model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared<optimizers::SGD>(optimizer),
rank.thread_rank(), std::dynamic_pointer_cast<GPT2>(model)->GetChunkSize());
model, pp_world_size, num_micro_batches, shapes, pp_rank, rank.thread_rank(),
std::dynamic_pointer_cast<GPT2>(model)->GetChunkSize());
if (ddp_world_size > 1) {
auto ddp_config
= DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
(*mutable_chunks)[chunk_id]
= std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank.thread_rank());
(*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
rank.thread_rank(), ddp_config);
}
}
} else if (ddp_world_size > 1) {
// NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
// before wrapping the model with DistributedDataParallel (DDP).
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());
auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
}

DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
@@ -237,6 +239,37 @@ void Train(const nn::parallel::Rank &rank) {
tokenizer = std::make_unique<Tokenizer>(FLAGS_tokenizer_bin);
}

// TODO(dcj): support more complex optimizer later
// auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate);
auto optimizer_creator = optimizers::SGD::Create(FLAGS_learning_rate);
std::shared_ptr<Optimizer> optimizer = nullptr;

if (FLAGS_use_distributed_optimizer) {
std::vector<std::shared_ptr<ParamAndGradBuffer>> param_grad_buffers;
std::vector<std::shared_ptr<ParamAndGradBucketGroup>> bucket_groups;

if (pp_world_size > 1 && ddp_world_size > 1) {
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
auto buffers
= dynamic_cast<DistributedDataParallel *>(mutable_chunks->at(chunk_id).get())->param_grad_buffers();
auto groups
= dynamic_cast<DistributedDataParallel *>(mutable_chunks->at(chunk_id).get())->bucket_groups();
param_grad_buffers.insert(param_grad_buffers.end(), buffers.begin(), buffers.end());
bucket_groups.insert(bucket_groups.end(), groups.begin(), groups.end());
}
} else if (ddp_world_size > 1) {
param_grad_buffers = dynamic_cast<DistributedDataParallel *>(model.get())->param_grad_buffers();
bucket_groups = dynamic_cast<DistributedDataParallel *>(model.get())->bucket_groups();
}

optimizer = std::make_shared<nn::parallel::DistributedOptimizer>(optimizer_creator, model->Parameters(),
param_grad_buffers, bucket_groups, ddp_pg,
ddp_world_size, ddp_rank);
} else {
optimizer = optimizer_creator(model->Parameters());
}

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(
@@ -245,11 +278,17 @@ void Train(const nn::parallel::Rank &rank) {
loss_fn->To(device);
LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";

auto cuda_device = device->IsCUDA() ? dynamic_cast<const CudaDevice *>(device) : nullptr;

LOG(INFO) << "start training";

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
const bool last_step = step == FLAGS_num_iteration;

if (cuda_device) {
cuda_device->ResetMemPoolHighWatermarks();
}

const auto iter_start = std::chrono::high_resolution_clock::now();

// once in a while evaluate the validation dataset
@@ -276,7 +315,7 @@ void Train(const nn::parallel::Rank &rank) {
float lossf = 0.0f;
// model->Train();
if (pp_world_size == 1) {
optimizer.ZeroGrad();
optimizer->ZeroGrad();

// if we are trying to overfit a single batch, we reset the loader here
if (FLAGS_overfit_single_batch) {
@@ -315,7 +354,7 @@ void Train(const nn::parallel::Rank &rank) {
LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish backward";
}

optimizer.Step();
optimizer->Step();
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
@@ -324,7 +363,7 @@ void Train(const nn::parallel::Rank &rank) {
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));

lossf = model->TrainStep({x}, {y}, loss_fn, dtype);
lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype);
}

if (ddp_world_size > 1) {
@@ -339,10 +378,16 @@ void Train(const nn::parallel::Rank &rank) {
const double tps = FLAGS_total_batch_size / (duration_us / 1e6);

if (rank.IsLastRank()) {
LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
"DP={}, TP={}, SP={}, PP={})",
size_t used_mb = 0, reserved_mb = 0;
if (cuda_device) {
std::tie(used_mb, reserved_mb) = cuda_device->GetMemPoolPeakMB();
}

LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size);
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
if (tokenizer) {
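Note on the gpt2 change: the optimizer is no longer constructed up front via `optimizers::SGD(model->Parameters(), FLAGS_learning_rate)`. Instead, `optimizers::SGD::Create(FLAGS_learning_rate)` returns a creator that is invoked only after the model has been wrapped for DDP/PP, either directly or inside `nn::parallel::DistributedOptimizer`. The sketch below illustrates that creator pattern under assumed signatures; the toy `Tensor`/`Optimizer` types are placeholders, not the InfiniTrain classes.

```cpp
// Minimal sketch of the "optimizer creator" pattern inferred from this diff.
// SGD::Create's exact return type in InfiniTrain is an assumption here.
#include <memory>
#include <utility>
#include <vector>

struct Tensor { /* parameter and gradient storage elided */ };

class Optimizer {
public:
    virtual ~Optimizer() = default;
    virtual void ZeroGrad() = 0;
    virtual void Step() = 0;
};

class SGD : public Optimizer {
public:
    SGD(std::vector<std::shared_ptr<Tensor>> params, float lr) : params_(std::move(params)), lr_(lr) {}
    void ZeroGrad() override { /* zero every parameter's gradient */ }
    void Step() override { /* p -= lr_ * p.grad for every parameter */ }

    // Factory bound to the hyper-parameters only; the parameter list is supplied
    // later, once the final set of (possibly wrapped or sharded) tensors is known.
    static auto Create(float lr) {
        return [lr](std::vector<std::shared_ptr<Tensor>> params) {
            return std::make_shared<SGD>(std::move(params), lr);
        };
    }

private:
    std::vector<std::shared_ptr<Tensor>> params_;
    float lr_;
};
```

With this shape, the training loop builds `auto optimizer_creator = optimizers::SGD::Create(FLAGS_learning_rate)` once and either calls `optimizer_creator(model->Parameters())` for the plain path or hands the creator to `DistributedOptimizer`, which can invoke it over each data-parallel rank's shard. That is why optimizer construction now happens after the DDP/PP wrapping rather than before it.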
74 changes: 60 additions & 14 deletions example/llama3/main.cc
@@ -13,6 +13,7 @@
#include "infini_train/include/nn/modules/loss.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/distributed_data_parallel.h"
#include "infini_train/include/nn/parallel/distributed_optimizer.h"
#include "infini_train/include/nn/parallel/parallel_functional.h"
#include "infini_train/include/nn/parallel/pp/pipeline_parallel.h"
#include "infini_train/include/nn/parallel/rank.h"
@@ -48,6 +49,7 @@ DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation");
DEFINE_uint32(text_length, 64, "the length of the generated text");
// optimization
DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations");
DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer (only takes effect when DP > 1)");
// evaluation
DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?");
DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
@@ -170,31 +172,32 @@ void Train(const nn::parallel::Rank &rank) {

auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size);

// TODO(dcj): support more complex optimizer later
auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate);

if (pp_world_size > 1) {
// NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct
// when sequence parallelism (SP) is enabled, we need to divide by sp_world_size.
auto shapes = std::vector<std::vector<int64_t>>{
{FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}};

model = std::make_shared<nn::parallel::PipelineParallel>(
model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared<optimizers::Adam>(optimizer),
rank.thread_rank(), std::dynamic_pointer_cast<LLaMA3>(model)->GetChunkSize());
model, pp_world_size, num_micro_batches, shapes, pp_rank, rank.thread_rank(),
std::dynamic_pointer_cast<LLaMA3>(model)->GetChunkSize());
if (ddp_world_size > 1) {
auto ddp_config
= DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
(*mutable_chunks)[chunk_id]
= std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id), rank.thread_rank());
(*mutable_chunks)[chunk_id] = std::make_shared<DistributedDataParallel>(mutable_chunks->at(chunk_id),
rank.thread_rank(), ddp_config);
}
}
} else if (ddp_world_size > 1) {
// NOTE(dcj): Complete all device (.to(device)) and dtype (.to(dtype)) conversions
// before wrapping the model with DistributedDataParallel (DDP).
// Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
// are created during the conversion.
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank());

auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer};
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);
}

DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
@@ -216,16 +219,53 @@ void Train(const nn::parallel::Rank &rank) {
tokenizer = std::make_unique<Tokenizer>(FLAGS_tokenizer_bin);
}

// TODO(dcj): support more complex optimizer later
// auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate);
auto optimizer_creator = optimizers::Adam::Create(FLAGS_learning_rate);
std::shared_ptr<Optimizer> optimizer = nullptr;

if (FLAGS_use_distributed_optimizer) {
std::vector<std::shared_ptr<ParamAndGradBuffer>> param_grad_buffers;
std::vector<std::shared_ptr<ParamAndGradBucketGroup>> bucket_groups;

if (pp_world_size > 1 && ddp_world_size > 1) {
auto *mutable_chunks = dynamic_cast<nn::parallel::PipelineParallel *>(model.get())->mutable_chunks();
for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) {
auto buffers
= dynamic_cast<DistributedDataParallel *>(mutable_chunks->at(chunk_id).get())->param_grad_buffers();
auto groups
= dynamic_cast<DistributedDataParallel *>(mutable_chunks->at(chunk_id).get())->bucket_groups();
param_grad_buffers.insert(param_grad_buffers.end(), buffers.begin(), buffers.end());
bucket_groups.insert(bucket_groups.end(), groups.begin(), groups.end());
}
} else if (ddp_world_size > 1) {
param_grad_buffers = dynamic_cast<DistributedDataParallel *>(model.get())->param_grad_buffers();
bucket_groups = dynamic_cast<DistributedDataParallel *>(model.get())->bucket_groups();
}

optimizer = std::make_shared<nn::parallel::DistributedOptimizer>(optimizer_creator, model->Parameters(),
param_grad_buffers, bucket_groups, ddp_pg,
ddp_world_size, ddp_rank);
} else {
optimizer = optimizer_creator(model->Parameters());
}

auto train_iter = train_loader.begin();
std::shared_ptr<nn::Module> loss_fn
= (tp_world_size > 1) ? std::static_pointer_cast<nn::Module>(std::make_shared<VocabParallelCrossEntropyLoss>())
: std::static_pointer_cast<nn::Module>(std::make_shared<nn::CrossEntropyLoss>());
loss_fn->To(device);
LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training";

auto cuda_device = device->IsCUDA() ? dynamic_cast<const CudaDevice *>(device) : nullptr;

for (int step = 0; step < FLAGS_num_iteration + 1; ++step) {
const bool last_step = step == FLAGS_num_iteration;

if (cuda_device) {
cuda_device->ResetMemPoolHighWatermarks();
}

const auto iter_start = std::chrono::high_resolution_clock::now();

// once in a while evaluate the validation dataset
Expand All @@ -252,7 +292,7 @@ void Train(const nn::parallel::Rank &rank) {
float lossf = 0.0f;
if (pp_world_size == 1) {
// model->Train();
optimizer.ZeroGrad();
optimizer->ZeroGrad();

// if we are trying to overfit a single batch, we reset the loader here
if (FLAGS_overfit_single_batch) {
@@ -291,7 +331,7 @@ void Train(const nn::parallel::Rank &rank) {
LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish backward";
}

optimizer.Step();
optimizer->Step();
} else {
auto [x, y] = *train_iter;
// if we are trying to overfit a single batch, we reset the loader here by commenting out the line below
@@ -300,7 +340,7 @@ void Train(const nn::parallel::Rank &rank) {
x = std::make_shared<Tensor>(x->To(device));
y = std::make_shared<Tensor>(y->To(device));

lossf = model->TrainStep({x}, {y}, loss_fn, dtype);
lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype);
}

if (ddp_world_size > 1) {
Expand All @@ -315,10 +355,16 @@ void Train(const nn::parallel::Rank &rank) {
const double tps = FLAGS_total_batch_size / (duration_us / 1e6);

if (rank.IsLastRank()) {
LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, "
"DP={}, TP={}, SP={}, PP={})",
size_t used_mb = 0, reserved_mb = 0;
if (cuda_device) {
std::tie(used_mb, reserved_mb) = cuda_device->GetMemPoolPeakMB();
}

LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | "
"peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})",
step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f,
tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size);
tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size,
pp_world_size);

if ((step + 1) % FLAGS_freq_generate_txt == 0) {
// FIXME(jym): to support PP
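The NOTE(dcj) comments in both examples warn that every `.To(device)`/`.To(dtype)` conversion must finish before the model is wrapped in `DistributedDataParallel`, because DDP attaches gradient hooks to the parameter tensors that exist at wrap time and conversions allocate new tensors. A minimal, self-contained illustration of why the ordering matters follows; all types here are toy stand-ins, not InfiniTrain classes, and the hook mechanism is only an assumption about how DDP-style gradient synchronization is typically attached.

```cpp
// Toy demonstration: hooks registered before a .To() conversion are silently lost.
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

struct Tensor {
    std::function<void()> post_grad_hook; // fired after backward fills this tensor's grad
};

struct Module {
    std::vector<std::shared_ptr<Tensor>> params;
    // Simulates .To(device)/.To(dtype): conversions create *new* parameter tensors.
    void To() {
        for (auto &p : params) { p = std::make_shared<Tensor>(); }
    }
};

// Toy DDP wrapper: registers an "all-reduce" hook on each parameter present at wrap time.
void WrapWithDdp(Module &m) {
    for (auto &p : m.params) {
        p->post_grad_hook = [] { std::cout << "all-reduce grad\n"; };
    }
}

void Backward(const Module &m) {
    for (const auto &p : m.params) {
        if (p->post_grad_hook) { p->post_grad_hook(); } // missing hook => no gradient sync
    }
}

int main() {
    Module good;
    good.params = {std::make_shared<Tensor>()};
    good.To();          // finish conversions first...
    WrapWithDdp(good);  // ...then wrap: the hook survives
    Backward(good);     // prints "all-reduce grad"

    Module bad;
    bad.params = {std::make_shared<Tensor>()};
    WrapWithDdp(bad);   // wrap first...
    bad.To();           // ...conversion replaces the tensors, orphaning the hooks
    Backward(bad);      // prints nothing: gradients would never be synchronized
}
```

Both examples already follow the safe ordering, for the whole-model path and for the per-chunk pipeline path, which is exactly what the NOTE is guarding.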
3 changes: 3 additions & 0 deletions infini_train/include/device.h
@@ -70,6 +70,9 @@ class CudaDevice : public Device {

nn::parallel::Rank rank() const override;

void ResetMemPoolHighWatermarks() const;
std::pair<size_t, size_t> GetMemPoolPeakMB() const;

private:
CudaDevice(int8_t index);

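The two methods added to `CudaDevice` are what the training loops use for per-step peak-memory reporting: `ResetMemPoolHighWatermarks()` at the top of each iteration and `GetMemPoolPeakMB()` when the log line is formatted. A toy pool showing the assumed semantics (per-step reset, peak used/reserved reported in MB) is sketched below; it is not the repository's allocator.

```cpp
// Assumed high-watermark semantics behind the new CudaDevice methods.
#include <algorithm>
#include <cstddef>
#include <utility>

class ToyMemPool {
public:
    void Allocate(size_t bytes) {
        used_ += bytes;
        reserved_ = std::max(reserved_, used_);            // cached blocks never shrink here
        peak_used_ = std::max(peak_used_, used_);
        peak_reserved_ = std::max(peak_reserved_, reserved_);
    }
    void Free(size_t bytes) { used_ -= bytes; }             // reserved_ stays (block caching)

    // Mirrors CudaDevice::ResetMemPoolHighWatermarks(): forget earlier steps' peaks.
    void ResetHighWatermarks() { peak_used_ = used_; peak_reserved_ = reserved_; }

    // Mirrors CudaDevice::GetMemPoolPeakMB(): (peak used, peak reserved) in MB.
    std::pair<size_t, size_t> PeakMB() const {
        constexpr size_t kMB = 1024 * 1024;
        return {peak_used_ / kMB, peak_reserved_ / kMB};
    }

private:
    size_t used_ = 0, reserved_ = 0, peak_used_ = 0, peak_reserved_ = 0;
};
```

Because the watermarks are reset every iteration, the "peak used / peak reserved" figures in the new log line describe a single step rather than the whole run.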
4 changes: 3 additions & 1 deletion infini_train/include/nn/modules/module.h
@@ -11,6 +11,7 @@
namespace infini_train {
class Tensor;
class Device;
class Optimizer;
} // namespace infini_train

namespace infini_train::nn {
@@ -53,7 +54,8 @@ class Module : public std::enable_shared_from_this<Module> {
virtual std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors);

virtual float TrainStep(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
const std::vector<std::shared_ptr<Tensor>> &targets, const std::shared_ptr<Module> &loss_fn,
const std::vector<std::shared_ptr<Tensor>> &targets,
const std::shared_ptr<Optimizer> &optimizer, const std::shared_ptr<Module> &loss_fn,
DataType dtype) {
return 0.0f;
};
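The signature change threads the optimizer through `TrainStep` instead of baking it into the module at construction time (note the removal of the `std::make_shared<optimizers::SGD>(optimizer)` argument from the PipelineParallel constructor in both examples). The self-contained sketch below shows what an override of the new signature could look like; the toy wrapper and its micro-batch schedule are illustrative assumptions, not the actual `PipelineParallel::TrainStep`.

```cpp
// Toy stand-ins for the Module/Optimizer interfaces; only the call shape is the point.
#include <iostream>
#include <memory>
#include <vector>

struct Tensor {};
enum class DataType { kFLOAT32 };

struct Optimizer {
    virtual ~Optimizer() = default;
    virtual void ZeroGrad() { std::cout << "zero grad\n"; }
    virtual void Step()     { std::cout << "step\n"; }
};

struct Module {
    virtual ~Module() = default;
    virtual float TrainStep(const std::vector<std::shared_ptr<Tensor>> &inputs,
                            const std::vector<std::shared_ptr<Tensor>> &targets,
                            const std::shared_ptr<Optimizer> &optimizer,
                            const std::shared_ptr<Module> &loss_fn, DataType dtype) {
        return 0.0f; // default no-op, matching the base-class stub in module.h
    }
};

// A wrapper that owns the micro-batch schedule but not the optimizer: it only uses
// whatever optimizer the training loop passes in (possibly a DistributedOptimizer).
struct ToyPipelineWrapper : Module {
    int num_micro_batches = 4;
    float TrainStep(const std::vector<std::shared_ptr<Tensor>> &inputs,
                    const std::vector<std::shared_ptr<Tensor>> &targets,
                    const std::shared_ptr<Optimizer> &optimizer,
                    const std::shared_ptr<Module> &loss_fn, DataType dtype) override {
        optimizer->ZeroGrad();
        for (int mb = 0; mb < num_micro_batches; ++mb) {
            // forward/backward for one micro-batch would go here
        }
        optimizer->Step(); // one parameter update per global step
        return 0.0f;       // would return the averaged micro-batch loss
    }
};

int main() {
    ToyPipelineWrapper model;
    auto opt = std::make_shared<Optimizer>();
    model.TrainStep({}, {}, opt, nullptr, DataType::kFLOAT32);
}
```

Passing the optimizer at call time also means a single optimizer instance can serve every pipeline chunk, which matches how the examples now build one `DistributedOptimizer` from the param/grad buffers gathered across all chunks.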