diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index 4a34c464..aa47352a 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -15,6 +15,7 @@ #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/distributed_data_parallel.h" +#include "infini_train/include/nn/parallel/distributed_optimizer.h" #include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/parallel_functional.h" #include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" @@ -49,6 +50,7 @@ DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation"); DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization DEFINE_double(learning_rate, 1e-4, "learning rate warmup iterations"); +DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -190,9 +192,6 @@ void Train(const nn::parallel::Rank &rank) { auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size); - // TODO(dcj): support more complex optimizer later - auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate); - if (pp_world_size > 1) { // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size. @@ -200,13 +199,15 @@ void Train(const nn::parallel::Rank &rank) { {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; model = std::make_shared( - model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared(optimizer), - rank.thread_rank(), std::dynamic_pointer_cast(model)->GetChunkSize()); + model, pp_world_size, num_micro_batches, shapes, pp_rank, rank.thread_rank(), + std::dynamic_pointer_cast(model)->GetChunkSize()); if (ddp_world_size > 1) { + auto ddp_config + = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer}; auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { - (*mutable_chunks)[chunk_id] - = std::make_shared(mutable_chunks->at(chunk_id), rank.thread_rank()); + (*mutable_chunks)[chunk_id] = std::make_shared(mutable_chunks->at(chunk_id), + rank.thread_rank(), ddp_config); } } } else if (ddp_world_size > 1) { @@ -214,7 +215,8 @@ void Train(const nn::parallel::Rank &rank) { // before wrapping the model with DistributedDataParallel (DDP). // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors // are created during the conversion. 
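To make the ordering described in the NOTE above concrete, the sketch below (not part of the patch) shows the intended sequence in the non-pipeline branch; ConvertModel is a hypothetical stand-in for whatever dtype/device conversion the example actually performs, and the real patched construction follows in the hunk just below.

// Sketch only: convert first, wrap last, so DDP's hooks bind to the final parameter tensors.
ConvertModel(model, device, dtype);   // hypothetical helper; may recreate parameter tensors

auto ddp_config = DistributedDataParallelConfig{
    .use_distributed_optimizer = FLAGS_use_distributed_optimizer};
model = std::make_shared<DistributedDataParallel>(model, rank.thread_rank(), ddp_config);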
- model = std::make_shared(model, rank.thread_rank()); + auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer}; + model = std::make_shared(model, rank.thread_rank(), ddp_config); } DistributedDataLoader train_loader(std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), @@ -237,6 +239,37 @@ void Train(const nn::parallel::Rank &rank) { tokenizer = std::make_unique(FLAGS_tokenizer_bin); } + // TODO(dcj): support more complex optimizer later + // auto optimizer = optimizers::SGD(model->Parameters(), FLAGS_learning_rate); + auto optimizer_creator = optimizers::SGD::Create(FLAGS_learning_rate); + std::shared_ptr optimizer = nullptr; + + if (FLAGS_use_distributed_optimizer) { + std::vector> param_grad_buffers; + std::vector> bucket_groups; + + if (pp_world_size > 1 && ddp_world_size > 1) { + auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); + for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { + auto buffers + = dynamic_cast(mutable_chunks->at(chunk_id).get())->param_grad_buffers(); + auto groups + = dynamic_cast(mutable_chunks->at(chunk_id).get())->bucket_groups(); + param_grad_buffers.insert(param_grad_buffers.end(), buffers.begin(), buffers.end()); + bucket_groups.insert(bucket_groups.end(), groups.begin(), groups.end()); + } + } else if (ddp_world_size > 1) { + param_grad_buffers = dynamic_cast(model.get())->param_grad_buffers(); + bucket_groups = dynamic_cast(model.get())->bucket_groups(); + } + + optimizer = std::make_shared(optimizer_creator, model->Parameters(), + param_grad_buffers, bucket_groups, ddp_pg, + ddp_world_size, ddp_rank); + } else { + optimizer = optimizer_creator(model->Parameters()); + } + auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? std::static_pointer_cast( @@ -245,11 +278,17 @@ void Train(const nn::parallel::Rank &rank) { loss_fn->To(device); LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training"; + auto cuda_device = device->IsCUDA() ? 
dynamic_cast(device) : nullptr; + LOG(INFO) << "start training"; for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { const bool last_step = step == FLAGS_num_iteration; + if (cuda_device) { + cuda_device->ResetMemPoolHighWatermarks(); + } + const auto iter_start = std::chrono::high_resolution_clock::now(); // once in a while evaluate the validation dataset @@ -276,7 +315,7 @@ void Train(const nn::parallel::Rank &rank) { float lossf = 0.0f; // model->Train(); if (pp_world_size == 1) { - optimizer.ZeroGrad(); + optimizer->ZeroGrad(); // if we are trying to overfit a single batch, we reset the loader here if (FLAGS_overfit_single_batch) { @@ -315,7 +354,7 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish backward"; } - optimizer.Step(); + optimizer->Step(); } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below @@ -324,7 +363,7 @@ void Train(const nn::parallel::Rank &rank) { x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); - lossf = model->TrainStep({x}, {y}, loss_fn, dtype); + lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); } if (ddp_world_size > 1) { @@ -339,10 +378,16 @@ void Train(const nn::parallel::Rank &rank) { const double tps = FLAGS_total_batch_size / (duration_us / 1e6); if (rank.IsLastRank()) { - LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, " - "DP={}, TP={}, SP={}, PP={})", + size_t used_mb = 0, reserved_mb = 0; + if (cuda_device) { + std::tie(used_mb, reserved_mb) = cuda_device->GetMemPoolPeakMB(); + } + + LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " + "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, - tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); + tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, + pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { if (tokenizer) { diff --git a/example/llama3/main.cc b/example/llama3/main.cc index fdea2162..aa584aea 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -13,6 +13,7 @@ #include "infini_train/include/nn/modules/loss.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/distributed_data_parallel.h" +#include "infini_train/include/nn/parallel/distributed_optimizer.h" #include "infini_train/include/nn/parallel/parallel_functional.h" #include "infini_train/include/nn/parallel/pp/pipeline_parallel.h" #include "infini_train/include/nn/parallel/rank.h" @@ -48,6 +49,7 @@ DEFINE_uint32(freq_generate_txt, 10, "frequency of text generation"); DEFINE_uint32(text_length, 64, "the length of the generated text"); // optimization DEFINE_double(learning_rate, 1e-5, "learning rate warmup iterations"); +DEFINE_bool(use_distributed_optimizer, false, "Whether to enable DistributedOptimizer(only take effects when DP>1)"); // evaluation DEFINE_uint32(val_loss_every, 0, "every how many steps to evaluate val loss?"); DEFINE_uint32(sample_every, 0, "how often to sample from the model?"); @@ -170,9 +172,6 @@ void Train(const nn::parallel::Rank &rank) { auto num_micro_batches = FLAGS_total_batch_size / (FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size); - // TODO(dcj): support more complex optimizer later - auto optimizer = 
optimizers::Adam(model->Parameters(), FLAGS_learning_rate); - if (pp_world_size > 1) { // NOTE(dcj): To ensure that the tensor shapes at the pipeline stage boundaries remain correct // when sequence parallelism (SP) is enabled, we need to divide by sp_world_size. @@ -180,13 +179,15 @@ void Train(const nn::parallel::Rank &rank) { {FLAGS_batch_size, FLAGS_sequence_length / sp_world_size, model_config.n_embd}}; model = std::make_shared( - model, pp_world_size, num_micro_batches, shapes, pp_rank, std::make_shared(optimizer), - rank.thread_rank(), std::dynamic_pointer_cast(model)->GetChunkSize()); + model, pp_world_size, num_micro_batches, shapes, pp_rank, rank.thread_rank(), + std::dynamic_pointer_cast(model)->GetChunkSize()); if (ddp_world_size > 1) { + auto ddp_config + = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer}; auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { - (*mutable_chunks)[chunk_id] - = std::make_shared(mutable_chunks->at(chunk_id), rank.thread_rank()); + (*mutable_chunks)[chunk_id] = std::make_shared(mutable_chunks->at(chunk_id), + rank.thread_rank(), ddp_config); } } } else if (ddp_world_size > 1) { @@ -194,7 +195,9 @@ void Train(const nn::parallel::Rank &rank) { // before wrapping the model with DistributedDataParallel (DDP). // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors // are created during the conversion. - model = std::make_shared(model, rank.thread_rank()); + + auto ddp_config = DistributedDataParallelConfig{.use_distributed_optimizer = FLAGS_use_distributed_optimizer}; + model = std::make_shared(model, rank.thread_rank(), ddp_config); } DistributedDataLoader train_loader(std::make_shared(FLAGS_input_bin, FLAGS_sequence_length), @@ -216,6 +219,37 @@ void Train(const nn::parallel::Rank &rank) { tokenizer = std::make_unique(FLAGS_tokenizer_bin); } + // TODO(dcj): support more complex optimizer later + // auto optimizer = optimizers::Adam(model->Parameters(), FLAGS_learning_rate); + auto optimizer_creator = optimizers::Adam::Create(FLAGS_learning_rate); + std::shared_ptr optimizer = nullptr; + + if (FLAGS_use_distributed_optimizer) { + std::vector> param_grad_buffers; + std::vector> bucket_groups; + + if (pp_world_size > 1 && ddp_world_size > 1) { + auto *mutable_chunks = dynamic_cast(model.get())->mutable_chunks(); + for (int chunk_id = 0; chunk_id < mutable_chunks->size(); ++chunk_id) { + auto buffers + = dynamic_cast(mutable_chunks->at(chunk_id).get())->param_grad_buffers(); + auto groups + = dynamic_cast(mutable_chunks->at(chunk_id).get())->bucket_groups(); + param_grad_buffers.insert(param_grad_buffers.end(), buffers.begin(), buffers.end()); + bucket_groups.insert(bucket_groups.end(), groups.begin(), groups.end()); + } + } else if (ddp_world_size > 1) { + param_grad_buffers = dynamic_cast(model.get())->param_grad_buffers(); + bucket_groups = dynamic_cast(model.get())->bucket_groups(); + } + + optimizer = std::make_shared(optimizer_creator, model->Parameters(), + param_grad_buffers, bucket_groups, ddp_pg, + ddp_world_size, ddp_rank); + } else { + optimizer = optimizer_creator(model->Parameters()); + } + auto train_iter = train_loader.begin(); std::shared_ptr loss_fn = (tp_world_size > 1) ? 
std::static_pointer_cast(std::make_shared()) @@ -223,9 +257,15 @@ void Train(const nn::parallel::Rank &rank) { loss_fn->To(device); LOG(INFO) << "Rank " << rank.GlobalRank() << ": start training"; + auto cuda_device = device->IsCUDA() ? dynamic_cast(device) : nullptr; + for (int step = 0; step < FLAGS_num_iteration + 1; ++step) { const bool last_step = step == FLAGS_num_iteration; + if (cuda_device) { + cuda_device->ResetMemPoolHighWatermarks(); + } + const auto iter_start = std::chrono::high_resolution_clock::now(); // once in a while evaluate the validation dataset @@ -252,7 +292,7 @@ void Train(const nn::parallel::Rank &rank) { float lossf = 0.0f; if (pp_world_size == 1) { // model->Train(); - optimizer.ZeroGrad(); + optimizer->ZeroGrad(); // if we are trying to overfit a single batch, we reset the loader here if (FLAGS_overfit_single_batch) { @@ -291,7 +331,7 @@ void Train(const nn::parallel::Rank &rank) { LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish backward"; } - optimizer.Step(); + optimizer->Step(); } else { auto [x, y] = *train_iter; // if we are trying to overfit a single batch, we reset the loader here by commenting out the line below @@ -300,7 +340,7 @@ void Train(const nn::parallel::Rank &rank) { x = std::make_shared(x->To(device)); y = std::make_shared(y->To(device)); - lossf = model->TrainStep({x}, {y}, loss_fn, dtype); + lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype); } if (ddp_world_size > 1) { @@ -315,10 +355,16 @@ void Train(const nn::parallel::Rank &rank) { const double tps = FLAGS_total_batch_size / (duration_us / 1e6); if (rank.IsLastRank()) { - LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s, " - "DP={}, TP={}, SP={}, PP={})", + size_t used_mb = 0, reserved_mb = 0; + if (cuda_device) { + std::tie(used_mb, reserved_mb) = cuda_device->GetMemPoolPeakMB(); + } + + LOG(ERROR) << std::format("step {:4d}/{} | train loss {:.6f} | lr {:.2e} | ({:.2f} ms | {:.0f} tok/s | " + "peak used: {:5d} MB | peak reserved: {:5d} MB, DP={}, TP={}, SP={}, PP={})", step + 1, FLAGS_num_iteration, lossf, FLAGS_learning_rate, duration_us / 1e3f, - tps, ddp_world_size, tp_world_size, sp_world_size, pp_world_size); + tps, used_mb, reserved_mb, ddp_world_size, tp_world_size, sp_world_size, + pp_world_size); if ((step + 1) % FLAGS_freq_generate_txt == 0) { // FIXME(jym): to support PP diff --git a/infini_train/include/device.h b/infini_train/include/device.h index 36357a09..6537c3f5 100644 --- a/infini_train/include/device.h +++ b/infini_train/include/device.h @@ -70,6 +70,9 @@ class CudaDevice : public Device { nn::parallel::Rank rank() const override; + void ResetMemPoolHighWatermarks() const; + std::pair GetMemPoolPeakMB() const; + private: CudaDevice(int8_t index); diff --git a/infini_train/include/nn/modules/module.h b/infini_train/include/nn/modules/module.h index 9bc78bcc..02f62edf 100644 --- a/infini_train/include/nn/modules/module.h +++ b/infini_train/include/nn/modules/module.h @@ -11,6 +11,7 @@ namespace infini_train { class Tensor; class Device; +class Optimizer; } // namespace infini_train namespace infini_train::nn { @@ -53,7 +54,8 @@ class Module : public std::enable_shared_from_this { virtual std::vector> Forward(const std::vector> &input_tensors); virtual float TrainStep(const std::vector> &input_tensors, - const std::vector> &targets, const std::shared_ptr &loss_fn, + const std::vector> &targets, + const std::shared_ptr &optimizer, const std::shared_ptr &loss_fn, DataType dtype) { return 0.0f; }; diff 
--git a/infini_train/include/nn/parallel/distributed_data_parallel.h b/infini_train/include/nn/parallel/distributed_data_parallel.h index 6001a17a..ea3c99f3 100644 --- a/infini_train/include/nn/parallel/distributed_data_parallel.h +++ b/infini_train/include/nn/parallel/distributed_data_parallel.h @@ -3,6 +3,8 @@ #include #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/distributed_data_parallel_config.h" +#include "infini_train/include/nn/parallel/param_and_grad_buffer.h" #include "infini_train/include/nn/parallel/reducer.h" namespace infini_train { @@ -14,13 +16,31 @@ namespace infini_train::nn::parallel { class DistributedDataParallel : public nn::Module { public: - DistributedDataParallel(std::shared_ptr module, int device_id, - const ReducerOptions &opts = ReducerOptions{}); + DistributedDataParallel(std::shared_ptr module, int thread_rank, + DistributedDataParallelConfig ddp_config = DistributedDataParallelConfig()); std::vector> Forward(const std::vector> &input_tensors) override; + DistributedDataParallelConfig ddp_config() const { return ddp_config_; } + + const std::vector> ¶m_grad_buffers() const { return param_grad_buffers_; } + + const std::vector> &bucket_groups() const { return bucket_groups_; } + +private: + void BuildParamAndGradBuffers(); + void RegisterBackwardHooks(); + void OnGradReady(const std::shared_ptr ¶m); + private: std::shared_ptr reducer_ = nullptr; + + DistributedDataParallelConfig ddp_config_; + const ProcessGroup *ddp_pg_ = nullptr; + + std::vector> param_grad_buffers_; + std::vector> bucket_groups_; + std::unordered_map> param_to_bucket_group_; }; } // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/distributed_data_parallel_config.h b/infini_train/include/nn/parallel/distributed_data_parallel_config.h new file mode 100644 index 00000000..72a34143 --- /dev/null +++ b/infini_train/include/nn/parallel/distributed_data_parallel_config.h @@ -0,0 +1,64 @@ +#pragma once + +#include + +namespace infini_train::nn::parallel { +namespace { +// Default bucket size in alignment with PyTorch +constexpr int kFirstBucketCapMB = 1; +constexpr int kNormalBucketCapMB = 25; +} // namespace + +class DistributedDataParallelConfig { +public: + // ====================================================== + // Reducer-related args + // Ref: https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html + // ====================================================== + // Max capacity for each bucket(in MB). + size_t first_bucket_cap_mb = kFirstBucketCapMB; + size_t normal_bucket_cap_mb = kNormalBucketCapMB; + + // When set true, map param.grad directly to the slice of bucket.flat(same address in memory) instead of memcpy. + bool gradient_as_bucket_view = true; + + // Whether to enable gradient bucketing. + bool gradient_bucketing_enabled = true; + + // ====================================================== + // DistributedOptimizer-related args + // Ref: + // https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/distributed_data_parallel_config.py + // ====================================================== + // Whether to enable DistributedOptimizer (ZeRO-1 equivalent). + // When set true: + // 1) Gradients/params are managed by ParamAndGradBuffer and reduced in groups. + // 2) The classic DDP reducer path is not used (i.e., disable reducer/bucketing in the DDP sense). 
+ bool use_distributed_optimizer = false; + + // Whether to overlap gradient reduce-scatter/all-reduce with backward compute. + // In this case, grad reduce is triggered immediately when a grad is ready or till all grads are ready. + bool overlap_grad_reduce = true; + + // Whether to overlap parameter all-gather with forward compute. + bool overlap_param_gather = true; + + // Whether to average values inside collectives (divide by world size) instead of summing. + bool average_in_collective = true; + + // Whether to check NaNs/Infs/unusually large in gradients before collectives. + bool check_for_nan_in_grad = false; + bool check_for_large_grads = false; + + // Number of DistributedOptimizer instances. + // Multiple DistOpt is used for building hierarchical collective groups for param/grad. + int num_distributed_optimizer_instances = 1; + + // Maximum number of parameters in each ParamAndGradBucket. + // This is distinct from DDP Reducer's MB-based bucket caps. + size_t bucket_size_in_elements = std::numeric_limits::max(); + + // Whether to pad bucket sizes to improve NCCL bus bandwidth utilization. + bool pad_buckets_for_high_nccl_busbw = false; +}; +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/distributed_optimizer.h b/infini_train/include/nn/parallel/distributed_optimizer.h new file mode 100644 index 00000000..6368049b --- /dev/null +++ b/infini_train/include/nn/parallel/distributed_optimizer.h @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include +#include + +#include "infini_train/include/nn/parallel/param_and_grad_buffer.h" +#include "infini_train/include/optimizer.h" + +namespace infini_train::nn::parallel { + +class DistributedOptimizer final : public infini_train::Optimizer { +public: + DistributedOptimizer(OptimizerCreator inner_optimizer_creator, + const std::vector> &full_params, + const std::vector> &buffers, + const std::vector> &bucket_groups, + const ProcessGroup *dp_pg, size_t dp_world_size, size_t ddp_rank); + + void Step() override; + + void ZeroGrad(bool set_to_none = true) override; + + void StartGradSync(); + void FinishGradSync(); + + void StartParamSync(bool force_sync = false); + void FinishParamSync(bool skip_next_bucket_dispatch = false); + +private: + void BuildShardParamsAndBindGrads(); + +private: + // Inherit from DDP model + std::vector> param_grad_buffers_; + std::vector> bucket_groups_; + + // DP info + const ProcessGroup *dp_pg_; + size_t dp_world_size_; + size_t dp_rank_; + + // shard params + std::vector> shard_params_; + + // Base optimizer (SGD, Adam and etc.) 
+ OptimizerCreator creator_; + std::shared_ptr base_optimizer_; +}; + +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/param_and_grad_buffer.h b/infini_train/include/nn/parallel/param_and_grad_buffer.h new file mode 100644 index 00000000..e7485d88 --- /dev/null +++ b/infini_train/include/nn/parallel/param_and_grad_buffer.h @@ -0,0 +1,166 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "infini_train/include/datatype.h" +#include "infini_train/include/device.h" +#include "infini_train/include/nn/parallel/distributed_data_parallel_config.h" + +namespace infini_train { +class Tensor; +namespace nn::parallel { +class ProcessGroup; +class Work; +} // namespace nn::parallel +} // namespace infini_train + +namespace infini_train::nn::parallel { +class ParamAndGradBucket { +public: + ParamAndGradBucket(const std::vector> ¶ms, const std::shared_ptr ¶m_data, + const std::shared_ptr &grad_data, size_t offset, size_t num_elements_unpadded, + float gradient_scaling_factor, size_t bucket_id); + + size_t bucket_id() const { return bucket_id_; } + + const std::vector> ¶ms() const { return params_; } + + const std::shared_ptr ¶m_data() const { return param_data_; } + + const std::shared_ptr &grad_data() const { return grad_data_; } + + size_t offset() const { return offset_; } + + size_t num_elements_unpadded() const { return num_elements_unpadded_; } + + float gradient_scaling_factor() const { return gradient_scaling_factor_; } + + bool GetTensorLocInBucket(const std::shared_ptr ¶meter, size_t &start_in_bucket, + size_t &end_in_bucket) const; + + void ScaleGradients(float scaling_factor); + +private: + int64_t bucket_id_ = 0; + std::vector> params_; + std::shared_ptr param_data_; + std::shared_ptr grad_data_; + + size_t offset_ = 0; + size_t num_elements_unpadded_ = 0; + float gradient_scaling_factor_ = 1.f; + + std::unordered_map> param_to_range_; +}; + +class ParamAndGradBucketGroup { +public: + ParamAndGradBucketGroup(const std::vector> &buckets, + const ProcessGroup *collective_pg, size_t process_group_size, + DistributedDataParallelConfig ddp_config); + + // Reset the state of this bucket group for the next training iter + void Reset(); + + // Register that the gradient of a parameter is ready, usually called in backward hook + // When all params in a bucket group are ready, will call StartGradSync() + void RegisterGradReady(const std::shared_ptr ¶meter); + + // Start grad reduce + void StartGradSync(); + + // Wait for gradient reduce to complete + void FinishGradSync(); + + // Start parameter all-gather + void StartParamSync(bool force_sync = false); + + // Wait for parameter all-gather to complete + void FinishParamSync(bool skip_next_bucket_dispatch = false); + + // TODO(zbl): For PP, set the next bucket group used for parameter all-gather. 
+ void SetNextParamGatherBucketGroup(std::shared_ptr next_group); + + const std::vector> &buckets() const { return buckets_; } + + const DistributedDataParallelConfig &config() const { return ddp_config_; } + +private: + std::vector> buckets_; + const ProcessGroup *collective_pg_ = nullptr; + size_t collective_pg_size_ = 1; + int rank_in_collective_pg_ = -1; + DistributedDataParallelConfig ddp_config_; + + std::unordered_set params_; + std::unordered_set params_with_grad_; + + // TODO(zbl): Implement CoalescedWork for aggregate works + // According to Megatron-LM's _coalescing_manager + std::vector> grad_reduce_work_list_; + std::vector> param_gather_work_list_; + + std::shared_ptr next_param_gather_bucket_group_ = nullptr; + + std::vector>> param_buffer_shard_list_; + std::vector>> grad_buffer_shard_list_; + + bool is_last_microbatch_ = true; + + bool grad_reduce_dispatched_ = false; + bool param_gather_dispatched_ = false; +}; + +class ParamAndGradBuffer { +public: + ParamAndGradBuffer(const std::vector> ¶ms, DataType ¶m_dtype, DataType &grad_dtype, + const ProcessGroup *ddp_pg, DistributedDataParallelConfig ddp_config); + + DistributedDataParallelConfig ddp_config() const { return ddp_config_; } + + std::shared_ptr param_buffer() const { return param_buffer_; } + + std::shared_ptr grad_buffer() const { return grad_buffer_; } + + const ProcessGroup *ddp_pg() const { return ddp_pg_; } + + size_t ddp_world_size() const { return ddp_world_size_; } + + std::vector> buckets() const { return buckets_; } + + void ScaleGradients(float scaling_factor); + + void Reset(); + +private: + void BuildBuckets(DataType param_dtype, DataType grad_dtype); + +private: + DistributedDataParallelConfig ddp_config_; + std::vector> params_; + std::shared_ptr param_buffer_; + std::shared_ptr grad_buffer_; + + size_t numel_ = 0; + size_t numel_unpadded_ = 0; + + const ProcessGroup *ddp_pg_ = nullptr; + size_t ddp_world_size_ = 1; + std::vector> buckets_; + + std::vector> bucket_indices_; + // Param to (start, end, bucket_id) + std::unordered_map> param_index_map_; + // Param to bucket + std::unordered_map> param_bucket_map_; +}; + +std::vector> +PartitionBuckets(const std::vector> &buffers, bool force_single_bucket_group); + +} // namespace infini_train::nn::parallel diff --git a/infini_train/include/nn/parallel/pp/pipeline_parallel.h b/infini_train/include/nn/parallel/pp/pipeline_parallel.h index 2ddd0918..6eef969d 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_parallel.h +++ b/infini_train/include/nn/parallel/pp/pipeline_parallel.h @@ -30,20 +30,18 @@ struct StageInfo { class PipelineParallel : public Module { public: PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches, - const std::vector> &recv_shape, int rank, - const std::shared_ptr &optimizer, int device_id, int vpp); + const std::vector> &recv_shape, int rank, int device_id, int vpp); float TrainStep(const std::vector> &input, - const std::vector> &target, const std::shared_ptr &loss_fn, - DataType dtype); + const std::vector> &target, const std::shared_ptr &optimizer, + const std::shared_ptr &loss_fn, DataType dtype) override; static StageInfo GetStageInfo(int total_layers, int pp_size, int pp_rank, int chunks_per_stage = 1); std::vector> *mutable_chunks(); private: - void BuildPipelineStage(const std::shared_ptr &optimizer, - const std::vector> &recv_shape, int device_id, + void BuildPipelineStage(const std::vector> &recv_shape, int device_id, std::vector> &&chunks); void SetupSchedule(int num_micro_batches); 
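The signature changes above move optimizer ownership out of the pipeline stage and into the training loop, which now constructs a single Optimizer (plain or distributed) and passes it to TrainStep. The sketch below condenses the resulting driver-side pattern from the example/gpt2/main.cc hunks earlier in this diff; template arguments and namespace qualification are spelled out for readability and may differ slightly from the actual sources.

// Condensed from example/gpt2/main.cc above; illustrative, not a new API.
auto optimizer_creator = optimizers::SGD::Create(FLAGS_learning_rate);
std::shared_ptr<Optimizer> optimizer;
if (FLAGS_use_distributed_optimizer) {
    // param_grad_buffers / bucket_groups are harvested from the DDP wrapper(s), as shown above.
    optimizer = std::make_shared<nn::parallel::DistributedOptimizer>(
        optimizer_creator, model->Parameters(), param_grad_buffers, bucket_groups,
        ddp_pg, ddp_world_size, ddp_rank);
} else {
    optimizer = optimizer_creator(model->Parameters());
}

if (pp_world_size == 1) {
    optimizer->ZeroGrad();
    // ... forward / backward ...
    optimizer->Step();
} else {
    // The schedule now receives the optimizer and presumably drives ZeroGrad()/Step() itself.
    lossf = model->TrainStep({x}, {y}, optimizer, loss_fn, dtype);
}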
diff --git a/infini_train/include/nn/parallel/pp/pipeline_schedule.h b/infini_train/include/nn/parallel/pp/pipeline_schedule.h index 3a8a9f75..053650d7 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_schedule.h +++ b/infini_train/include/nn/parallel/pp/pipeline_schedule.h @@ -7,6 +7,7 @@ namespace infini_train { class Tensor; +class Optimizer; namespace nn { class Module; } @@ -24,7 +25,7 @@ class PipelineSchedule { virtual ~PipelineSchedule() = default; float Step(std::shared_ptr input, std::shared_ptr target, - const std::shared_ptr &loss_fn, DataType dtype); + const std::shared_ptr &optimizer, const std::shared_ptr &loss_fn, DataType dtype); virtual float StepMicroBatches(const std::vector> &arg_mbs, const std::vector> &target_mbs, diff --git a/infini_train/include/nn/parallel/pp/pipeline_stage.h b/infini_train/include/nn/parallel/pp/pipeline_stage.h index 52a776ea..7a188cd4 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_stage.h +++ b/infini_train/include/nn/parallel/pp/pipeline_stage.h @@ -16,8 +16,8 @@ namespace infini_train::nn::parallel { class PipelineStage { public: - PipelineStage(int stage_index, int num_stages, const std::vector> &recv_shape, - std::shared_ptr optimizer, int device_id, std::vector> &&chunks); + PipelineStage(int stage_index, int num_stages, const std::vector> &recv_shape, int device_id, + std::vector> &&chunks); std::vector> ForwardOneChunk(const std::vector> &inputs, int local_chunk_idx = 0); @@ -32,7 +32,6 @@ class PipelineStage { const Device *device() const; const std::vector> &recv_shape() const; - std::shared_ptr optimizer(); const std::vector> &chunks(); std::vector> *mutable_chunks(); @@ -43,7 +42,6 @@ class PipelineStage { int next_rank_ = -1; const Device *device_ = nullptr; std::vector> chunks_; - std::shared_ptr optimizer_ = nullptr; std::vector> recv_shape_; }; diff --git a/infini_train/include/nn/parallel/reducer.h b/infini_train/include/nn/parallel/reducer.h index f729f723..8cded98b 100644 --- a/infini_train/include/nn/parallel/reducer.h +++ b/infini_train/include/nn/parallel/reducer.h @@ -6,6 +6,7 @@ #include #include "infini_train/include/datatype.h" +#include "infini_train/include/nn/parallel/distributed_data_parallel_config.h" #include "infini_train/include/nn/parallel/parallel_functional.h" namespace infini_train { @@ -21,9 +22,6 @@ class Work; namespace infini_train::nn::parallel { namespace { -// Default bucket size in alignment with PyTorch -constexpr int kFirstBucketCapMB = 1; -constexpr int kNormalBucketCapMB = 25; constexpr size_t kBytesPerMB = 1024ULL * 1024ULL; } // namespace @@ -46,22 +44,6 @@ std::vector> ComputeBucketAssignmentBySize(const std::vector const std::vector &bucket_size_limits, const std::vector &tensor_indices = {}); -struct ReducerOptions { - // Pack all Reducer-related args together - // Ref: https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html - - // Max capacity for each bucket(in MB) - size_t first_bucket_cap_mb = kFirstBucketCapMB; - size_t normal_bucket_cap_mb = kNormalBucketCapMB; - - // When set true, map param.grad directly to the slice of bucket.flat(same address in memory) instead of memcpy - bool gradient_as_bucket_view = true; - - // Whether to enable gradient bucketing - // FIXME(zbl): should enable gradient bucketing by default - bool gradient_bucketing_enabled = true; -}; - // DDP Reducer that handles gradient bucketing in backward // ref: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/reducer.hpp class Reducer : 
public std::enable_shared_from_this { @@ -70,10 +52,10 @@ class Reducer : public std::enable_shared_from_this { * * @param parameters A list of parameters for this process's single model replica * @param bucket_indices The bucket assignment for this reducer - * @param opts Other options, see definition of ReducerOptions + * @param ddp_config DDP related options, see definition of DistributedDataParallelConfig */ explicit Reducer(std::vector> parameters, std::vector> bucket_indices, - const ReducerOptions &opts); + const DistributedDataParallelConfig ddp_config = DistributedDataParallelConfig()); // Attach PostAllReduceHooks to params void AttachHooksToParameters(); @@ -156,7 +138,7 @@ class Reducer : public std::enable_shared_from_this { std::atomic buckets_finished_{0}; std::shared_ptr comm_hook_ = nullptr; - ReducerOptions opts_; + DistributedDataParallelConfig ddp_config_; // Next bucket to be reduced // This is to make sure that all-reduce of buckets be launched in the order we expect diff --git a/infini_train/include/optimizer.h b/infini_train/include/optimizer.h index 81221908..fb0ae2d5 100644 --- a/infini_train/include/optimizer.h +++ b/infini_train/include/optimizer.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include @@ -8,11 +9,15 @@ namespace infini_train { class Tensor; } namespace infini_train { +class Optimizer; + +using OptimizerCreator = std::function(const std::vector> ¶ms)>; + class Optimizer { public: explicit Optimizer(const std::vector> ¶ms); - void ZeroGrad(bool set_to_none = true); + virtual void ZeroGrad(bool set_to_none = true); virtual void Step() = 0; @@ -27,6 +32,12 @@ class SGD : public Optimizer { void Step() override; + static OptimizerCreator Create(float learning_rate) { + return [learning_rate](const std::vector> ¶ms) { + return std::make_shared(params, learning_rate); + }; + } + private: const float learning_rate_ = 0.0; }; @@ -38,6 +49,13 @@ class Adam : public Optimizer { void Step() override; + static OptimizerCreator Create(float learning_rate = 1e-3, float beta1 = 0.9, float beta2 = 0.999, + float eps = 1e-8) { + return [=](const std::vector> ¶ms) { + return std::make_shared(params, learning_rate, beta1, beta2, eps); + }; + } + private: int64_t t_; const float learning_rate_; diff --git a/infini_train/include/tensor.h b/infini_train/include/tensor.h index b499b604..cf694fe8 100644 --- a/infini_train/include/tensor.h +++ b/infini_train/include/tensor.h @@ -63,6 +63,8 @@ class Tensor : public std::enable_shared_from_this { Tensor(const Tensor &tensor, size_t offset, const std::vector &dims); + void SetData(const Tensor &tensor, size_t offset, bool overwrite = false); + Tensor(const float *data, const std::vector &dims, DataType dtype, const Device *device); Tensor(const float *data, const std::vector &dims, DataType dtype) : Tensor(data, dims, dtype, DeviceManager::Instance()->GetDevice(DeviceType::kCPU, 0)) {} @@ -203,6 +205,9 @@ class Tensor : public std::enable_shared_from_this { std::shared_ptr grad() const; void set_grad(const std::shared_ptr &grad); + std::shared_ptr main_grad() const; + void set_main_grad(const std::shared_ptr &grad); + bool requires_grad() const; void set_requires_grad(bool requires_grad); @@ -231,6 +236,8 @@ class Tensor : public std::enable_shared_from_this { private: std::shared_ptr grad_ = nullptr; + // Points to a view in flat buffer constantly + std::shared_ptr main_grad_ = nullptr; bool requires_grad_ = false; bool is_leaf_ = true; std::shared_ptr grad_fn_ = nullptr; diff --git a/infini_train/src/device.cc 
b/infini_train/src/device.cc index 4271ff97..c16aeee7 100644 --- a/infini_train/src/device.cc +++ b/infini_train/src/device.cc @@ -11,6 +11,10 @@ #endif namespace infini_train { +namespace { +constexpr size_t kBytesPerMB = 1024ULL * 1024ULL; +} // namespace + Device::Device(DeviceType type, int8_t index) : type_(type), index_(index) { if (type_ == DeviceType::kCPU && index_ != 0) { LOG(FATAL) << "CPU device index should be 0"; @@ -73,6 +77,32 @@ CudaDevice::CudaDevice(int8_t index) CUBLAS_CHECK(cublasCreate(&cublas_handle_)); CUBLAS_CHECK(cublasSetStream(cublas_handle_, stream_)); } + +void CudaDevice::ResetMemPoolHighWatermarks() const { + SetDevice(); + cudaMemPool_t pool; + CUDA_CHECK(cudaDeviceGetDefaultMemPool(&pool, index_)); + + cuuint64_t zero = 0; + // High watermark can only be reset to zero; non-zero is illegal. + CUDA_CHECK(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrUsedMemHigh, &zero)); + CUDA_CHECK(cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReservedMemHigh, &zero)); +} + +std::pair CudaDevice::GetMemPoolPeakMB() const { + SetDevice(); + cudaMemPool_t pool; + CUDA_CHECK(cudaDeviceGetDefaultMemPool(&pool, index_)); + + cuuint64_t used = 0; + CUDA_CHECK(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrUsedMemHigh, &used)); + + cuuint64_t reserved = 0; + CUDA_CHECK(cudaMemPoolGetAttribute(pool, cudaMemPoolAttrReservedMemHigh, &reserved)); + + return std::make_pair(static_cast(used / kBytesPerMB), + static_cast(reserved / kBytesPerMB)); +} #endif // USE_CUDA const DeviceManager *DeviceManager::Instance() { diff --git a/infini_train/src/nn/parallel/distributed_data_parallel.cc b/infini_train/src/nn/parallel/distributed_data_parallel.cc index a25a7d16..ad76f13e 100644 --- a/infini_train/src/nn/parallel/distributed_data_parallel.cc +++ b/infini_train/src/nn/parallel/distributed_data_parallel.cc @@ -1,11 +1,14 @@ #include "infini_train/include/nn/parallel/distributed_data_parallel.h" +#include #include +#include #include #include "glog/logging.h" #include "infini_train/include/autograd/function_hook.h" +#include "infini_train/include/dispatcher.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/parallel_functional.h" #include "infini_train/include/nn/parallel/process_group.h" @@ -17,37 +20,147 @@ namespace { constexpr char kModuleName[] = "module"; } // namespace -DistributedDataParallel::DistributedDataParallel(std::shared_ptr module, int device_id, - const ReducerOptions &opts) { +DistributedDataParallel::DistributedDataParallel(std::shared_ptr module, int thread_rank, + const DistributedDataParallelConfig ddp_config) + : ddp_config_(ddp_config), + ddp_pg_(ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(thread_rank))) { for (auto ¶m : module->Parameters()) { auto device = param->GetDevice(); - CHECK_EQ(device->Index(), device_id) << "All parameters must be on the same device as the module"; - if (!opts.gradient_bucketing_enabled) { - auto ddp_pg - = ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().thread_rank())); + CHECK_EQ(device->Index(), thread_rank) << "All parameters must be on the same device as the module"; + if (!ddp_config.gradient_bucketing_enabled && !ddp_config.use_distributed_optimizer) { auto hook = std::make_unique( - function::ReduceOpType::kAvg, ddp_pg); + function::ReduceOpType::kAvg, ddp_pg_); param->RegisterPostAccumulateGradHook(std::move(hook)); } } for (auto &buffer : module->Buffers()) { - CHECK_EQ(buffer->GetDevice()->Index(), device_id) << "All buffers 
must be on the same device as the module"; + CHECK_EQ(buffer->GetDevice()->Index(), thread_rank) << "All buffers must be on the same device as the module"; } modules_[kModuleName] = std::move(module); - if (opts.gradient_bucketing_enabled) { + if (ddp_config.use_distributed_optimizer) { + BuildParamAndGradBuffers(); + RegisterBackwardHooks(); + } else if (ddp_config.gradient_bucketing_enabled) { // Bucket Assignment auto params = modules_[kModuleName]->Parameters(); - const size_t first_cap_bytes = opts.first_bucket_cap_mb * kBytesPerMB; - const size_t normal_cap_bytes = opts.normal_bucket_cap_mb * kBytesPerMB; + const size_t first_cap_bytes = ddp_config.first_bucket_cap_mb * kBytesPerMB; + const size_t normal_cap_bytes = ddp_config.normal_bucket_cap_mb * kBytesPerMB; std::vector bucket_size_limits = {first_cap_bytes, normal_cap_bytes}; auto bucket_indices = ComputeBucketAssignmentBySize(params, bucket_size_limits); - reducer_ = std::make_shared(params, bucket_indices, opts); + reducer_ = std::make_shared(params, bucket_indices, ddp_config); reducer_->AttachHooksToParameters(); } } +void DistributedDataParallel::BuildParamAndGradBuffers() { + // (param_dtype, grad_dtype) + using DTypePair = std::pair; + std::map>> dtype_to_params; + + for (auto param : modules_[kModuleName]->Parameters()) { + if (!param->requires_grad()) { + continue; + } + auto param_dtype = param->Dtype(); + auto grad_dtype = param->grad() ? param->grad()->Dtype() : param_dtype; + dtype_to_params[{param_dtype, grad_dtype}].push_back(param); + } + + param_grad_buffers_.clear(); + param_grad_buffers_.reserve(dtype_to_params.size()); + + for (auto &kv : dtype_to_params) { + auto [param_dtype, grad_dtype] = kv.first; + auto param_list = kv.second; + + if (param_list.empty()) { + continue; + } + + auto buffer = std::make_shared(param_list, param_dtype, grad_dtype, ddp_pg_, ddp_config_); + + param_grad_buffers_.push_back(buffer); + } + + // TODO(zbl): option for disable bucketing + bucket_groups_ = PartitionBuckets(param_grad_buffers_, /*force_single_bucket_group=*/false); + + if (ddp_config_.use_distributed_optimizer && ddp_config_.overlap_param_gather) { + auto num_bucket_groups = bucket_groups_.size(); + for (auto i = num_bucket_groups - 1; i > 0; --i) { + bucket_groups_[i]->SetNextParamGatherBucketGroup(bucket_groups_[i - 1]); + } + } + + param_to_bucket_group_.clear(); + for (auto &group : bucket_groups_) { + for (auto &bucket : group->buckets()) { + for (auto ¶m : bucket->params()) { + auto inserted = param_to_bucket_group_.emplace(param.get(), group).second; + if (!inserted) { + LOG(FATAL) << "Parameter appears in more than one bucket group."; + } + } + } + } + + LOG(INFO) << "DDP BuildParamAndGradBuffers: " + << "dtype_groups=" << dtype_to_params.size() << ", param_grad_buffers=" << param_grad_buffers_.size() + << ", bucket_groups=" << bucket_groups_.size(); +} + +void DistributedDataParallel::RegisterBackwardHooks() { + class DDPPostAccumulateHook final : public autograd::PostAccumulateGradHook { + public: + DDPPostAccumulateHook(DistributedDataParallel *ddp, const std::weak_ptr param) + : ddp_(ddp), param_(param) {} + + void operator()(const std::shared_ptr &) override { + if (auto param = param_.lock()) { + ddp_->OnGradReady(param); + } + } + + private: + DistributedDataParallel *ddp_; + std::weak_ptr param_; + }; + + auto &module = modules_.at(kModuleName); + for (auto ¶m : module->Parameters()) { + if (!param->requires_grad()) { + continue; + } + + auto hook = std::make_unique(this, param); + 
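// This hook runs after autograd finishes accumulating the parameter's gradient:
// OnGradReady() below folds param->grad() into the flat main_grad() view, drops the
// per-parameter grad, and, when overlap_grad_reduce is enabled, notifies the owning
// ParamAndGradBucketGroup so its reduce-scatter can start once the whole group is ready.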
param->RegisterPostAccumulateGradHook(std::move(hook)); + } +} + +void DistributedDataParallel::OnGradReady(const std::shared_ptr ¶m) { + auto it = param_to_bucket_group_.find(param.get()); + if (it != param_to_bucket_group_.end()) { + CHECK(param->requires_grad()); + if (ddp_config_.overlap_grad_reduce) { + CHECK(param->grad()) << "param.grad being None is not safe when overlap_grad_reduce is True"; + } + + if (param->grad()) { + // Add to main_grad(buffer) + auto kernel = Dispatcher::Instance().GetKernel({param->GetDevice()->Type(), "AccumulateGrad"}); + kernel.Call(param->grad(), 1.f, param->main_grad()); + } + // Can safely set grad to null because grad has already been added to main_grad(buffer) + param->set_grad(nullptr); + + if (ddp_config_.overlap_grad_reduce) { + it->second->RegisterGradReady(param); + } + } +} + std::vector> DistributedDataParallel::Forward(const std::vector> &input_tensors) { auto outputs = modules_[kModuleName]->Forward(input_tensors); diff --git a/infini_train/src/nn/parallel/distributed_optimizer.cc b/infini_train/src/nn/parallel/distributed_optimizer.cc new file mode 100644 index 00000000..1a185455 --- /dev/null +++ b/infini_train/src/nn/parallel/distributed_optimizer.cc @@ -0,0 +1,134 @@ +#include "infini_train/include/nn/parallel/distributed_optimizer.h" + +#include "glog/logging.h" + +#include "infini_train/include/device.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::parallel { + +namespace { +std::shared_ptr GetShardView(const std::shared_ptr &buffer, size_t world_size, size_t rank) { + + CHECK(buffer); + CHECK_GT(world_size, 0); + CHECK_LT(rank, world_size); + CHECK_EQ(buffer->NumElements() % world_size, 0); + + const size_t shard_numel = buffer->NumElements() / world_size; + const size_t offset_bytes = shard_numel * rank * kDataTypeToSize.at(buffer->Dtype()); + + return std::make_shared(*buffer, offset_bytes, std::vector{static_cast(shard_numel)}); +} + +} // namespace + +DistributedOptimizer::DistributedOptimizer(OptimizerCreator creator, + const std::vector> &full_params, + const std::vector> &buffers, + const std::vector> &bucket_groups, + const ProcessGroup *dp_pg, size_t dp_world_size, size_t dp_rank) + : Optimizer(full_params), param_grad_buffers_(buffers), bucket_groups_(bucket_groups), dp_pg_(dp_pg), + dp_world_size_(dp_world_size), dp_rank_(dp_rank), creator_(std::move(creator)) { + + CHECK(dp_pg_); + CHECK(dp_world_size_ > 1) << "DistributedOptimizer: dp_world_size must be greater than 1."; + + BuildShardParamsAndBindGrads(); + + // Build base optimizer + base_optimizer_ = creator_(shard_params_); + CHECK(base_optimizer_) << "DistributedOptimizer: failed to create base optimizer."; +} + +void DistributedOptimizer::BuildShardParamsAndBindGrads() { + shard_params_.clear(); + + for (const auto &group : bucket_groups_) { + for (const auto &bucket : group->buckets()) { + + auto bucket_param = bucket->param_data(); + auto bucket_grad = bucket->grad_data(); + + CHECK(bucket_param) << "DistributedOptimizer requires param buffer."; + CHECK(bucket_grad) << "DistributedOptimizer requires grad buffer."; + + CHECK_EQ(bucket_param->NumElements() % dp_world_size_, 0); + const size_t bucket_shard_numel = bucket_param->NumElements() / dp_world_size_; + const size_t bucket_shard_start = dp_rank_ * bucket_shard_numel; + const size_t bucket_shard_end = bucket_shard_start + bucket_shard_numel; + + // Iterate param in bucket, build each param(or param_shard) seperately + for (const auto ¶m : bucket->params()) { + size_t 
param_start_in_bucket = 0, param_end_in_bucket = 0; + auto found = bucket->GetTensorLocInBucket(param, param_start_in_bucket, param_end_in_bucket); + CHECK(found) << "DistributedOptimizer: param not found in bucket mapping."; + + const size_t local_start = std::max(param_start_in_bucket, bucket_shard_start); + const size_t local_end = std::min(param_end_in_bucket, bucket_shard_end); + if (local_end <= local_start) { + // this rank owns no elements for this param + continue; + } + + const size_t piece_numel = local_end - local_start; + CHECK_GT(piece_numel, 0); + + const size_t param_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_param->Dtype()); + const size_t grad_piece_offset_bytes = local_start * kDataTypeToSize.at(bucket_grad->Dtype()); + + auto param_piece = std::make_shared(*bucket_param, param_piece_offset_bytes, + std::vector{static_cast(piece_numel)}); + + auto grad_piece = std::make_shared(*bucket_grad, grad_piece_offset_bytes, + std::vector{static_cast(piece_numel)}); + + param_piece->set_grad(grad_piece); + shard_params_.push_back(param_piece); + } + } + } + + CHECK(!shard_params_.empty()) << "DistributedOptimizer: this DP rank owns no param pieces. " + << "Check bucket padding/divisibility and param bucketing order."; +} + +void DistributedOptimizer::StartGradSync() { + for (auto &group : bucket_groups_) { group->StartGradSync(); } +} + +void DistributedOptimizer::FinishGradSync() { + for (auto &group : bucket_groups_) { group->FinishGradSync(); } +} + +void DistributedOptimizer::StartParamSync(bool force_sync) { + for (auto &group : bucket_groups_) { group->StartParamSync(force_sync); } +} + +void DistributedOptimizer::FinishParamSync(bool skip_next_bucket_dispatch) { + for (auto &group : bucket_groups_) { group->FinishParamSync(skip_next_bucket_dispatch); } +} + +void DistributedOptimizer::ZeroGrad(bool set_to_none) { + // Zero main_grad buffer and clear BucketGroup state + for (auto &buffer : param_grad_buffers_) { buffer->Reset(); } + for (auto &group : bucket_groups_) { group->Reset(); } + // Call base class's method: Zero each param's grad to guarantee consistency + infini_train::Optimizer::ZeroGrad(set_to_none); +} + +void DistributedOptimizer::Step() { + // 1. Ensure grads are synced + FinishGradSync(); + + // 2. Base optimizer step on owned param pieces + CHECK(base_optimizer_) << "DistributedOptimizer: base optimizer is null."; + base_optimizer_->Step(); + + // 3. 
Gather updated param shards back to full params + StartParamSync(/*force_sync=*/false); + // FIXME(zbl): Call sync before param is actually used in next step + FinishParamSync(/*skip_next_bucket_dispatch=*/true); +} + +} // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/param_and_grad_buffer.cc b/infini_train/src/nn/parallel/param_and_grad_buffer.cc new file mode 100644 index 00000000..6297fe4c --- /dev/null +++ b/infini_train/src/nn/parallel/param_and_grad_buffer.cc @@ -0,0 +1,581 @@ +#include "infini_train/include/nn/parallel/param_and_grad_buffer.h" + +#include +#include +#include + +#include "glog/logging.h" + +#include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/distributed_data_parallel_config.h" +#include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/process_group.h" +#include "infini_train/include/nn/parallel/reduce_op_type.h" +#include "infini_train/include/nn/parallel/work.h" +#include "infini_train/include/tensor.h" + +namespace infini_train::nn::parallel { + +namespace { +constexpr size_t kParamStartAlignElements = 64; +constexpr size_t kBucketEndAlignElements = 128; + +inline size_t PadTo(size_t value, size_t alignment) { + if (alignment == 0) { + return value; + } + size_t remainder = value % alignment; + return remainder == 0 ? value : value + (alignment - remainder); +} + +std::shared_ptr AllocateFlatBuffer(size_t num_elements, DataType data_type, const Device *device) { + std::vector dims = {static_cast(num_elements)}; + // TODO(zbl): replace with united allocation when memory pool is available + return std::make_shared(dims, data_type, device); +} + +std::shared_ptr GetBufferView(const std::shared_ptr buffer, size_t start_in_elements, + const std::vector &dims) { + return std::make_shared(*buffer, start_in_elements * kDataTypeToSize.at(buffer->Dtype()), dims); +}; + +std::vector> ShardBuffer(const std::shared_ptr buffer, size_t ddp_world_size) { + CHECK_EQ(buffer->NumElements() % ddp_world_size, 0); + size_t shard_size = buffer->NumElements() / ddp_world_size; + std::vector> sharded_buffer; + for (auto i = 0; i < ddp_world_size; ++i) { + sharded_buffer.push_back( + GetBufferView(buffer, i * shard_size, std::vector{static_cast(shard_size)})); + } + return sharded_buffer; +} + +} // namespace + +ParamAndGradBucket::ParamAndGradBucket(const std::vector> ¶ms, + const std::shared_ptr ¶m_data, + const std::shared_ptr &grad_data, size_t offset, + size_t num_elements_unpadded, float gradient_scaling_factor, size_t bucket_id) + : bucket_id_(bucket_id), params_(std::move(params)), param_data_(std::move(param_data)), + grad_data_(std::move(grad_data)), offset_(offset), num_elements_unpadded_(num_elements_unpadded), + gradient_scaling_factor_(gradient_scaling_factor) { + size_t current_offset = 0; + for (const auto ¶m : params_) { + auto numel = param->NumElements(); + param_to_range_.emplace(param.get(), std::make_pair(current_offset, current_offset + numel)); + current_offset += numel; + } +} + +bool ParamAndGradBucket::GetTensorLocInBucket(const std::shared_ptr ¶meter, size_t &start_in_bucket, + size_t &end_in_bucket) const { + const auto iterator = param_to_range_.find(parameter.get()); + if (iterator == param_to_range_.end()) { + return false; + } + start_in_bucket = iterator->second.first; + end_in_bucket = iterator->second.second; + return true; +} + +void ParamAndGradBucket::ScaleGradients(float scaling_factor) { + if (!grad_data_ || scaling_factor == 1.f) { + 
return; + } + + // FIXME(zbl): should perform in-place multiply + // grad_data_ *= scaling_factor; + LOG(FATAL) << "ParamAndGradBucket: Should not arrive here"; +} + +ParamAndGradBucketGroup::ParamAndGradBucketGroup(const std::vector> &buckets, + const ProcessGroup *collective_pg, size_t process_group_size, + DistributedDataParallelConfig ddp_config) + : buckets_(std::move(buckets)), collective_pg_(collective_pg), collective_pg_size_(process_group_size), + ddp_config_(ddp_config) { + // TODO(zbl): support hierarchical gradient sync in distopt + CHECK(ddp_config.num_distributed_optimizer_instances == 1) + << "ParamAndGradBucketGroup: Multi-instance DistributedOptimizer is not supported yet."; + + for (const auto &bucket : buckets_) { + for (const auto ¶m : bucket->params()) { params_.insert(param.get()); } + } + if (rank_in_collective_pg_ == -1) { + auto param = *params_.begin(); + // FIXME(zbl): get correct rank in multi-node settings + rank_in_collective_pg_ + = collective_pg_->GetGroupRank(dynamic_cast(param->GetDevice())->rank().thread_rank()); + } + + param_buffer_shard_list_.resize(buckets_.size()); + grad_buffer_shard_list_.resize(buckets_.size()); +} + +void ParamAndGradBucketGroup::Reset() { + params_with_grad_.clear(); + grad_reduce_work_list_.clear(); + param_gather_work_list_.clear(); + is_last_microbatch_ = true; + grad_reduce_dispatched_ = false; + param_gather_dispatched_ = false; +} + +void ParamAndGradBucketGroup::RegisterGradReady(const std::shared_ptr ¶meter) { + if (!ddp_config_.overlap_grad_reduce) { + LOG(WARNING) + << "ParamAndGradBucketGroup: RegisterGradReady() should only be called when overlap_grad_reduce is " + "True. Skipping here."; + return; + } + + // Only register grads as ready when processing the last microbatch + if (is_last_microbatch_) { + if (!parameter || params_.find(parameter.get()) == params_.end()) { + return; + } + + const bool inserted = params_with_grad_.insert(parameter.get()).second; + if (!inserted) { + LOG(FATAL) << "ParamAndGradBucketGroup: RegisterGradReady() was called twice for the same parameter in a " + "bucket group."; + return; + } + + if (params_with_grad_.size() == params_.size()) { + // All param grads are ready in this group, trigger grad sync + StartGradSync(); + } + } +} + +void ParamAndGradBucketGroup::StartGradSync() { + if (!collective_pg_) { + LOG(FATAL) << "ParamAndGradBucketGroup: StartGradSync() called with null collective_pg_."; + return; + } + + if (grad_reduce_dispatched_) { + return; + } + if (!grad_reduce_work_list_.empty()) { + grad_reduce_dispatched_ = true; + return; + } + + // TODO(zbl): Check NaN/Inf/too large in grad (options in DistributedDataParallelConfig) + + for (auto bucket : buckets_) { + if (bucket->gradient_scaling_factor() != 1.f) { + bucket->ScaleGradients(bucket->gradient_scaling_factor()); + } + } + + auto reduce_op = ddp_config_.average_in_collective ? 
function::ReduceOpType::kAvg : function::ReduceOpType::kSum; + auto async_op = ddp_config_.overlap_grad_reduce && (ddp_config_.num_distributed_optimizer_instances == 1); + + for (auto i = 0; i < buckets_.size(); ++i) { + auto bucket = buckets_[i]; + std::shared_ptr grad_buffer = bucket->grad_data(); + if (!grad_buffer) { + continue; + } + + if (ddp_config_.use_distributed_optimizer) { + if (grad_buffer_shard_list_[i].empty()) { + grad_buffer_shard_list_[i] = ShardBuffer(grad_buffer, collective_pg_size_); + } + auto local_data_view = grad_buffer_shard_list_[i][rank_in_collective_pg_]; + grad_reduce_work_list_.push_back( + collective_pg_->ReduceScatter(local_data_view, grad_buffer, reduce_op, async_op)); + } else { + // NOTE(zbl): Should not arrive here because Reducer-related logic is activated when not using DistOpt + grad_reduce_work_list_.push_back(collective_pg_->AllReduce(grad_buffer, reduce_op, async_op)); + } + } + + grad_reduce_dispatched_ = true; +} + +void ParamAndGradBucketGroup::FinishGradSync() { + if (!grad_reduce_dispatched_) { + StartGradSync(); + } + + if (!ddp_config_.overlap_grad_reduce) { + // Assume reduce ops are synced and no work needs to be resolved + grad_reduce_work_list_.clear(); + grad_reduce_dispatched_ = false; + return; + } + + CHECK(!grad_reduce_work_list_.empty()) + << "ParamAndGradBucketGroup: Communication call has not been issued for this bucket(" + << params_with_grad_.size() << "/" << params_.size() << " params have grad available)"; + + for (auto work : grad_reduce_work_list_) { work->WaitNonBlocking(); } + grad_reduce_work_list_.clear(); + grad_reduce_dispatched_ = false; +} + +void ParamAndGradBucketGroup::StartParamSync(bool force_sync) { + CHECK(ddp_config_.use_distributed_optimizer); + + if (!collective_pg_) { + LOG(ERROR) << "ParamAndGradBucketGroup: StartParamSync called with null collective_pg_."; + return; + } + + if (force_sync) { + // force synchronous collective regardless of other settings + for (auto work : param_gather_work_list_) { work->WaitNonBlocking(); } + param_gather_work_list_.clear(); + return; + } else { + CHECK(param_gather_work_list_.empty()); + } + + auto async_op = ddp_config_.overlap_param_gather && (!force_sync); + + for (auto i = 0; i < buckets_.size(); ++i) { + auto bucket = buckets_[i]; + std::shared_ptr param_buffer = bucket->param_data(); + if (!param_buffer) { + continue; + } + + if (param_buffer_shard_list_[i].empty()) { + param_buffer_shard_list_[i] = ShardBuffer(param_buffer, collective_pg_size_); + } + auto local_data_view = param_buffer_shard_list_[i][rank_in_collective_pg_]; + param_gather_work_list_.push_back(collective_pg_->AllGather(param_buffer, local_data_view, async_op)); + } + + param_gather_dispatched_ = true; +} + +void ParamAndGradBucketGroup::FinishParamSync(bool skip_next_bucket_dispatch) { + if (!ddp_config_.use_distributed_optimizer || !ddp_config_.overlap_param_gather) { + return; + } + + if (!param_gather_dispatched_) { + StartParamSync(); + } + + if (!param_gather_work_list_.empty()) { + for (auto work : param_gather_work_list_) { work->WaitNonBlocking(); } + param_gather_work_list_.clear(); + param_gather_dispatched_ = false; + + if (next_param_gather_bucket_group_ && !skip_next_bucket_dispatch) { + if (next_param_gather_bucket_group_->param_gather_dispatched_) { + LOG(WARNING) + << "ParamAndGradBucketGroup: The next bucket's parameter all-gather operation has already been " + "dispatched. 
This may be caused by a mismatch between the order of parameter registration and " + "forward pass execution, which will hurt the communication-computation overlap performance."; + } else { + next_param_gather_bucket_group_->StartParamSync(); + } + } + } +} + +void ParamAndGradBucketGroup::SetNextParamGatherBucketGroup(std::shared_ptr next_group) { + next_param_gather_bucket_group_ = next_group; +} + +ParamAndGradBuffer::ParamAndGradBuffer(const std::vector> &params, DataType &param_dtype, + DataType &grad_dtype, const ProcessGroup *ddp_pg, + DistributedDataParallelConfig ddp_config) + : params_(std::move(params)), ddp_pg_(std::move(ddp_pg)), ddp_config_(ddp_config) { + if (ddp_pg_) { + ddp_world_size_ = global::GetDataParallelSize(); + } + + BuildBuckets(param_dtype, grad_dtype); +} + +void ParamAndGradBuffer::BuildBuckets(DataType param_dtype, DataType grad_dtype) { + // Pack parameters in buffer, allocate memory, and build buckets. + + // Param start must be multiple of 64 + auto PadParamStartIfNeeded = [&](size_t start) -> size_t { + if (ddp_config_.use_distributed_optimizer) { + // According to Megatron-LM, make sure each param starts at 128B aligned address (by default align to 64 + // elements for precision >=16-bit) + return PadTo(start, kParamStartAlignElements); + } + return start; + }; + + // Bucket size should be a multiple of ddp size and 128 (sweet spot for NCCL) + auto PadBucketEndIfNeeded = [&](size_t bucket_end_index) -> size_t { + if (ddp_config_.use_distributed_optimizer) { + // According to Megatron-LM, ensure that all buckets start at a memory address that is 256B + // aligned (128 values since params and grads use >= 16-bit precision) + size_t lcm_val = std::lcm(ddp_world_size_, kBucketEndAlignElements); + if (ddp_config_.pad_buckets_for_high_nccl_busbw) { + // According to Megatron-LM, when the bucket size is divisible by a large power of 2 (2^16), + // NCCL collectives can have high bus bandwidth at large DP counts + lcm_val = std::lcm(lcm_val, static_cast(1u << 16)); + } + return PadTo(bucket_end_index, lcm_val); + } + return bucket_end_index; + }; + + size_t param_start_index = 0; + size_t param_end_index = 0; + size_t bucket_start_index = 0; + size_t bucket_end_index = 0; + size_t bucket_id = 0; + std::vector> bucket_params; + std::vector per_bucket_numel_unpadded; + + auto UpdateBucketMetadata = [&](size_t param_end_index) -> size_t { + // calculate numel when bucket is unpadded + const size_t numel_unpadded_bucket = param_end_index - bucket_start_index; + per_bucket_numel_unpadded.push_back(numel_unpadded_bucket); + + // calculate bucket_end_index with padding, save the range of bucket in buffer + size_t bucket_end_index = PadBucketEndIfNeeded(param_end_index); + bucket_indices_.push_back({bucket_start_index, bucket_end_index}); + + // move ptr to next bucket + bucket_start_index = bucket_end_index; + bucket_params.clear(); + ++bucket_id; + return bucket_end_index; + }; + + // 1. Pack params into buffer, in backprop order + for (auto it = params_.rbegin(); it != params_.rend(); ++it) { + const auto &param = *it; + param_start_index = PadParamStartIfNeeded(param_start_index); + + // TODO(zbl): check whether there are params that need their own bucket + // if (DoesParamRequiresNewBucket(param)) { ...
} + + param_end_index = param_start_index + param->NumElements(); + param_index_map_[param.get()] = {param_start_index, param_end_index, bucket_id}; + bucket_params.push_back(param); + + if ((param_end_index - bucket_start_index) >= ddp_config_.bucket_size_in_elements) { + // If current bucket is full, then wrap up + // NOTE(zbl): Actual bucket size might be larger than the configured bucket size + bucket_end_index = UpdateBucketMetadata(param_end_index); + param_start_index = bucket_end_index; + } else { + param_start_index = param_end_index; + } + } + + // If the last bucket is not full, still wrap it up + if (!bucket_params.empty()) { + bucket_end_index = UpdateBucketMetadata(param_end_index); + } + + // numel with padding = bucket end + numel_ = bucket_end_index; + // numel without padding + numel_unpadded_ = std::accumulate(per_bucket_numel_unpadded.begin(), per_bucket_numel_unpadded.end(), + static_cast(0), std::plus()); + + CHECK(numel_unpadded_ <= numel_); + if (ddp_config_.use_distributed_optimizer) { + // numel must be multiple of ddp size (so that reduce-scatter could easily shard the buffer among ranks) + CHECK_EQ(numel_ % ddp_world_size_, 0); + } else { + CHECK_EQ(numel_, numel_unpadded_); + } + + // 2. Allocate buffer + auto device = params_.front()->GetDevice(); + if (ddp_config_.use_distributed_optimizer) { + param_buffer_ = AllocateFlatBuffer(numel_, param_dtype, device); + } else { + // No param buffer needed if optimizer is not distributed + param_buffer_.reset(); + } + grad_buffer_ = AllocateFlatBuffer(numel_, grad_dtype, device); + + LOG(INFO) << "ParamAndGradBuffer: numel_unpadded=" << numel_unpadded_ << ", numel (padded)=" << numel_; + + // 3. Build buckets, and map param/grad to views of buffers + bucket_params.clear(); + bucket_start_index = 0; + size_t current_bucket_id = 0; + + // Helper function to create ParamAndGradBucket object + auto NewBucket + = [&](const std::vector> &bucket_params, size_t start_index, size_t end_index, + size_t num_elements_unpadded, size_t bucket_id) -> std::shared_ptr { + if (ddp_config_.use_distributed_optimizer) { + CHECK_EQ(start_index % ddp_world_size_, 0); + CHECK_EQ(end_index % ddp_world_size_, 0); + CHECK_EQ(bucket_indices_.at(bucket_id).first, start_index); + CHECK_EQ(bucket_indices_.at(bucket_id).second, end_index); + } + + std::shared_ptr bucket_param_view; + if (param_buffer_) { + bucket_param_view = GetBufferView(param_buffer_, start_index, + std::vector{static_cast(end_index - start_index)}); + } + std::shared_ptr bucket_grad_view = GetBufferView( + grad_buffer_, start_index, std::vector{static_cast(end_index - start_index)}); + + // FIXME(zbl): Use default for now + float gradient_scaling_factor = 1.0f; + auto bucket + = std::make_shared(bucket_params, bucket_param_view, bucket_grad_view, start_index, + num_elements_unpadded, gradient_scaling_factor, bucket_id); + + for (auto param : bucket_params) { + CHECK(param_bucket_map_.find(param.get()) == param_bucket_map_.end()) + << "Parameter appears in multiple buckets."; + param_bucket_map_[param.get()] = bucket; + } + + return std::move(bucket); + }; + + // Iterate params in backprop order, build ParamAndGradBucket object + for (auto it = params_.rbegin(); it != params_.rend(); ++it) { + const auto &param = *it; + std::tie(param_start_index, param_end_index, bucket_id) = param_index_map_.at(param.get()); + + // Remap param/grad pointers + if (param_buffer_) { + // FIXME(zbl): change tensor buffer + param->SetData(*param_buffer_, param_start_index * kDataTypeToSize.at(param_buffer_->Dtype()),
true); + } + + auto grad_view = GetBufferView(grad_buffer_, param_start_index, param->Dims()); + param->set_main_grad(grad_view); + + if (current_bucket_id != bucket_id) { + const auto &range = bucket_indices_.at(current_bucket_id); + CHECK_EQ(range.first, bucket_start_index) << "Bucket start mismatch."; + + bucket_end_index = range.second; + buckets_.push_back(NewBucket(bucket_params, bucket_start_index, bucket_end_index, + per_bucket_numel_unpadded[current_bucket_id], current_bucket_id)); + + bucket_start_index = bucket_end_index; + bucket_params.clear(); + + CHECK_EQ(current_bucket_id + 1, buckets_.size()); + CHECK_EQ(current_bucket_id + 1, bucket_id); + current_bucket_id = bucket_id; + } + bucket_params.push_back(param); + } + + // If the last bucket is not full, still wrap it up + if (!bucket_params.empty()) { + const auto &range = bucket_indices_.at(current_bucket_id); + CHECK_EQ(range.first, bucket_start_index) << "Last bucket start mismatch."; + buckets_.push_back(NewBucket(bucket_params, bucket_start_index, bucket_end_index, + per_bucket_numel_unpadded[current_bucket_id], current_bucket_id)); + } +} + +void ParamAndGradBuffer::ScaleGradients(float scaling_factor) { + if (!grad_buffer_ || scaling_factor == 1.f) { + return; + } + + // FIXME(zbl): should perform in-place multiply + // grad_data_ *= scaling_factor; + LOG(FATAL) << "ParamAndGradBuffer: Should not arrive here"; +} + +void ParamAndGradBuffer::Reset() { + if (!grad_buffer_) { + return; + } + grad_buffer_->Fill(0.f); +} + +/*** + Automatically regroup the buckets of input buffers and return a list of bucket groups. + + In some scenarios, we need to put buckets from different buffers into a group so that their + communication can be aggregated. + + For example, when there are both fp8 weights and bf16 biases in the model and virtual + pipeline parallelism is enabled, each model chunk will have an fp8 bucket and a bf16 bucket, + which doubles the number of communication kernels, and because of the use of + CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back communications will prevent the + overlap of communication kernels with computation kernels. + + The grouping strategy is: + 1. If force_single_bucket_group is true, put all buckets across all buffers into a single + bucket group. + 2. If force_single_bucket_group is false, when there is no fp8 buffer in the input buffers, + let each bucket group have only one bucket. + 3. If force_single_bucket_group is false, when using fp8 params, merge all non-fp8 buckets + into the last fp8 bucket group. + - Since the non-fp8 parameters (typically the biases of various layers) are relatively + small, they are likely to be grouped into a single non-fp8 bucket. + - The fp8 buckets start from the end of the model, i.e., the first bucket corresponds to + the end of the model, while the last bucket corresponds to the beginning. + - If we combine the non-fp8 bucket with the first fp8 bucket, we cannot initiate the + reduce-scatter to synchronize gradients after the backward pass at the end of the model + has completed. This is because we need to wait for the non-fp8 params from the beginning + layers to obtain their gradients. + - Combining the non-fp8 bucket with the last fp8 bucket can help avoid this issue. + + Args: + buffers (list): list of input buffers. + force_single_bucket_group (bool, optional): force grouping all buckets across all buffers + into a single bucket group.
+***/ +std::vector> +PartitionBuckets(const std::vector> &buffers, bool force_single_bucket_group) { + std::vector> bucket_groups; + + if (buffers.empty()) { + return bucket_groups; + } + + // Case 1: Put all buckets into a single bucket group if force_single_bucket_group is True. + if (force_single_bucket_group) { + std::vector> all_buckets; + auto ddp_config = buffers.front()->ddp_config(); + auto ddp_pg = buffers.front()->ddp_pg(); + auto ddp_world_size = buffers.front()->ddp_world_size(); + + for (const auto &buffer : buffers) { + // TODO(zbl): override == for ddp config + // CHECK(buffer->ddp_config() == ddp_config) << "PartitionBuckets: buffers have different ddp_config."; + CHECK(buffer->ddp_pg() == ddp_pg) << "PartitionBuckets: buffers have different ddp_pg."; + CHECK(buffer->ddp_world_size() == ddp_world_size) + << "PartitionBuckets: buffers have different ddp_world_size."; + + all_buckets.insert(all_buckets.end(), buffer->buckets().begin(), buffer->buckets().end()); + } + + bucket_groups.push_back( + std::make_shared(all_buckets, ddp_pg, ddp_world_size, ddp_config)); + return bucket_groups; + } + + // Case 2: When there is no fp8 buffer in the input buffers, let each bucket group have only one bucket. + for (const auto &buffer : buffers) { + const auto &buffer_buckets = buffer->buckets(); + for (const auto &bucket : buffer_buckets) { + std::vector> single_bucket_list; + single_bucket_list.push_back(bucket); + bucket_groups.push_back(std::make_shared( + single_bucket_list, buffer->ddp_pg(), buffer->ddp_world_size(), buffer->ddp_config())); + } + } + + // TODO(zbl): Support fp8 params + // Case 3: When using fp8 params, merge all non-fp8 buckets into the last fp8 bucket group. + return bucket_groups; +} + +} // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc index bbdceaea..34d2272a 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc @@ -9,7 +9,6 @@ #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/nn/parallel/pp/pipeline_schedule.h" #include "infini_train/include/nn/parallel/pp/pipeline_stage.h" -#include "infini_train/include/optimizer.h" namespace infini_train::nn::parallel { namespace { @@ -18,11 +17,9 @@ constexpr char kModuleName[] = "module"; thread_local int pp_rank = 0; -void PipelineParallel::BuildPipelineStage(const std::shared_ptr &optimizer, - const std::vector> &recv_shape, int device_id, +void PipelineParallel::BuildPipelineStage(const std::vector> &recv_shape, int device_id, std::vector> &&chunks) { - pipeline_stage_ - = std::make_shared(rank_, num_stages_, recv_shape, optimizer, device_id, std::move(chunks)); + pipeline_stage_ = std::make_shared(rank_, num_stages_, recv_shape, device_id, std::move(chunks)); } void PipelineParallel::SetupSchedule(int num_micro_batches) { @@ -31,14 +28,15 @@ void PipelineParallel::SetupSchedule(int num_micro_batches) { float PipelineParallel::TrainStep(const std::vector> &input, const std::vector> &target, - const std::shared_ptr &loss_fn, DataType dtype) { + const std::shared_ptr &optimizer, const std::shared_ptr &loss_fn, + DataType dtype) { std::shared_ptr stage_input; std::shared_ptr stage_target = target[0]; if (rank_ == 0) { stage_input = input[0]; } - return schedule_->Step(stage_input, stage_target, loss_fn, dtype); + return schedule_->Step(stage_input, stage_target, optimizer, loss_fn, dtype); } StageInfo 
PipelineParallel::GetStageInfo(int total_layers, int pp_size, int rank, int chunks_per_stage) { @@ -79,8 +77,8 @@ StageInfo PipelineParallel::GetStageInfo(int total_layers, int pp_size, int rank } PipelineParallel::PipelineParallel(const std::shared_ptr module, int num_stages, int num_micro_batches, - const std::vector> &recv_shape, int pp_rank, - const std::shared_ptr &optimizer, int device_id, int chunk_size) + const std::vector> &recv_shape, int pp_rank, int device_id, + int chunk_size) : num_stages_(num_stages), rank_(pp_rank) { modules_[kModuleName] = std::move(module); @@ -100,7 +98,7 @@ PipelineParallel::PipelineParallel(const std::shared_ptr module, int num chunks.push_back(std::make_shared(std::move(chunk_parts))); } - BuildPipelineStage(optimizer, recv_shape, device_id, std::move(chunks)); + BuildPipelineStage(recv_shape, device_id, std::move(chunks)); SetupSchedule(num_micro_batches); } diff --git a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc index 95dd3bbc..d2324679 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_schedule.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_schedule.cc @@ -267,7 +267,8 @@ float PipelineSchedule::StepMicroBatches(const std::vector input, std::shared_ptr target, - const std::shared_ptr &loss_fn, DataType dtype) { + const std::shared_ptr &optimizer, const std::shared_ptr &loss_fn, + DataType dtype) { std::vector> micro_batches(num_micro_batches_); std::vector> target_mbs(num_micro_batches_); if (stage_->IsFirstStage()) { @@ -278,8 +279,6 @@ float PipelineSchedule::Step(std::shared_ptr input, std::shared_ptrSplit(target->Dims()[0] / num_micro_batches_); } - const auto &optimizer = stage_->optimizer(); - optimizer->ZeroGrad(); float lossf = StepMicroBatches(micro_batches, target_mbs, loss_fn, dtype); diff --git a/infini_train/src/nn/parallel/pp/pipeline_stage.cc b/infini_train/src/nn/parallel/pp/pipeline_stage.cc index 582b9bd2..78a1aba6 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_stage.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_stage.cc @@ -11,11 +11,10 @@ namespace infini_train::nn::parallel { PipelineStage::PipelineStage(int stage_index /* pp_rank */, int num_stages /* pp_size */, - const std::vector> &recv_shape, std::shared_ptr optimizer, - int device_id, std::vector> &&chunks) + const std::vector> &recv_shape, int device_id, + std::vector> &&chunks) : stage_index_(stage_index), num_stages_(num_stages), prev_rank_(stage_index > 0 ? stage_index - 1 : -1), next_rank_(stage_index < num_stages - 1 ? 
stage_index + 1 : -1), recv_shape_(recv_shape), - optimizer_(std::move(optimizer)), device_(DeviceManager::Instance()->GetAllAvailableDevices(DeviceType::kCUDA).at(device_id)), chunks_(std::move(chunks)) {} @@ -38,7 +37,7 @@ int PipelineStage::num_stages() const { return num_stages_; } const Device *PipelineStage::device() const { return device_; } const std::vector> &PipelineStage::recv_shape() const { return recv_shape_; } -std::shared_ptr PipelineStage::optimizer() { return optimizer_; } +// std::shared_ptr PipelineStage::optimizer() { return optimizer_; } const std::vector> &PipelineStage::chunks() { return chunks_; } std::vector> *PipelineStage::mutable_chunks() { return &chunks_; } } // namespace infini_train::nn::parallel diff --git a/infini_train/src/nn/parallel/reducer.cc b/infini_train/src/nn/parallel/reducer.cc index 0cdd7703..b27de0ef 100644 --- a/infini_train/src/nn/parallel/reducer.cc +++ b/infini_train/src/nn/parallel/reducer.cc @@ -164,8 +164,8 @@ std::vector> ComputeBucketAssignmentBySize(const std::vector } Reducer::Reducer(std::vector> parameters, std::vector> bucket_indices, - const ReducerOptions &opts) - : params_(std::move(parameters)), opts_(opts) { + const DistributedDataParallelConfig ddp_config) + : params_(std::move(parameters)), ddp_config_(ddp_config) { BuildBuckets(bucket_indices); ready_seen_this_iter_.assign(params_.size(), 0); } @@ -263,8 +263,8 @@ void Reducer::RebuildBuckets() { tensors_in_order.push_back(params_[global_idx]); } - const size_t first_cap_bytes = opts_.first_bucket_cap_mb * kBytesPerMB; - const size_t normal_cap_bytes = opts_.normal_bucket_cap_mb * kBytesPerMB; + const size_t first_cap_bytes = ddp_config_.first_bucket_cap_mb * kBytesPerMB; + const size_t normal_cap_bytes = ddp_config_.normal_bucket_cap_mb * kBytesPerMB; std::vector bucket_size_limits = {first_cap_bytes, normal_cap_bytes}; auto new_bucket_indices = ComputeBucketAssignmentBySize(tensors_in_order, bucket_size_limits, full_order); @@ -299,7 +299,7 @@ void Reducer::PrepareForBackward() { for (auto &bucket : buckets_) { bucket.pending = bucket.variables.size(); - if (opts_.gradient_as_bucket_view) { + if (ddp_config_.gradient_as_bucket_view) { for (size_t i = 0; i < bucket.variables.size(); ++i) { // Tie each param.grad to slice of contents const auto &param = bucket.variables[i]; @@ -358,7 +358,7 @@ void Reducer::MarkVariableReadyDense(size_t variable_index) { ready_seen_this_iter_[variable_index] = 1; } - if (!opts_.gradient_as_bucket_view) { + if (!ddp_config_.gradient_as_bucket_view) { auto grad = bucket.variables[loc.intra_bucket_index]->grad(); CHECK(grad && grad->Dtype() == bucket.dtype && grad->GetDevice() == bucket.contents->GetDevice()); CopyGradToBucket(grad, bucket.contents, bucket.offsets[loc.intra_bucket_index]); @@ -447,7 +447,7 @@ void Reducer::FinalizeBackward() { if (!bucket.work) { continue; } - if (!opts_.gradient_as_bucket_view) { + if (!ddp_config_.gradient_as_bucket_view) { for (size_t i = 0; i < bucket.variables.size(); ++i) { // NOTE(zbl): For better performance, try to directly assign bucket slice to grad instead of copying // i.e.
bucket.variables[i]->set_grad(bucket.bucket_views_in[i]); diff --git a/infini_train/src/optimizer.cc b/infini_train/src/optimizer.cc index 80e3887f..d24c88e4 100644 --- a/infini_train/src/optimizer.cc +++ b/infini_train/src/optimizer.cc @@ -1,5 +1,6 @@ #include "infini_train/include/optimizer.h" +#include #include #include "infini_train/include/device.h" diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index c5de11ce..d0e23431 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -113,6 +113,22 @@ Tensor::Tensor(const float *data, const std::vector &dims, DataType dty } } +void Tensor::SetData(const Tensor &tensor, size_t offset, bool overwrite) { + CHECK(tensor.GetDevice() == GetDevice()); + CHECK(tensor.Dtype() == Dtype()); + CHECK_LE(tensor.offset_ + offset + SizeInBytes(), tensor.buffer_->Size()); + + if (overwrite) { + // Create a view of original tensor buffer + auto new_tensor = Tensor(tensor, offset, Dims()); + // Copy in data + new_tensor.CopyFrom(*this); + } + + buffer_ = tensor.buffer_; + offset_ = tensor.offset_ + offset; +} + const Device *Tensor::GetDevice() const { return buffer_->GetDevice(); } void *Tensor::DataPtr() { return reinterpret_cast(buffer_->DataPtr()) + offset_; } @@ -579,6 +595,18 @@ void Tensor::set_grad(const std::shared_ptr &grad) { } } +std::shared_ptr Tensor::main_grad() const { return main_grad_; }; +void Tensor::set_main_grad(const std::shared_ptr &grad) { + if (grad) { + CHECK(grad->GetDevice() == GetDevice()); + CHECK(grad->Dtype() == Dtype()); + CHECK(grad->Dims() == Dims()); + main_grad_ = grad; + } else { + main_grad_.reset(); + } +} + bool Tensor::requires_grad() const { return requires_grad_; } void Tensor::set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; } diff --git a/scripts/write_to_feishu_sheet.py b/scripts/write_to_feishu_sheet.py index 93a63aec..28c41661 100644 --- a/scripts/write_to_feishu_sheet.py +++ b/scripts/write_to_feishu_sheet.py @@ -3,11 +3,14 @@ import time import os import argparse +import glob import re import pandas as pd from datetime import datetime, date import subprocess +# date/branch/commit/avg_latency/avg_throughput/peak_used/peak_reserved +META_COLS=7 HEADER_ROWS=5 HEADER_COLS="W" @@ -116,13 +119,86 @@ def set_style(self, spreadsheet_token, sheet_id, entry_index): return self._feishu_request("PUT", f"/sheets/v2/spreadsheets/{spreadsheet_token}/styles_batch_update", json=payload) is not None def merge_columns(self, spreadsheet_token, sheet_id): - """Merge columns A5:E9""" + """Merge columns A5:G9""" # API reference:https://open.feishu.cn/document/server-docs/docs/sheets-v3/data-operation/merge-cells start = HEADER_ROWS end = HEADER_ROWS + 4 - payload = {"range": f"{sheet_id}!A{start}:E{end}", "mergeType": "MERGE_COLUMNS"} + payload = {"range": f"{sheet_id}!A{start}:G{end}", "mergeType": "MERGE_COLUMNS"} return self._feishu_request("POST", f"/sheets/v2/spreadsheets/{spreadsheet_token}/merge_cells", json=payload) is not None + def write_cmd_args_to_header(self, spreadsheet_token, cmd_args, sheet_id): + """Write command args to A1:W1""" + def col_letter_to_idx(letter: str) -> int: + """A->1, Z->26, AA->27 ...""" + letter = letter.strip().upper() + idx = 0 + for ch in letter: + idx = idx * 26 + (ord(ch) - ord('A') + 1) + return idx + + data = [cmd_args] + [""] * (col_letter_to_idx(HEADER_COLS) - 1) + payload = {"valueRange": {"range": f"{sheet_id}!A1:W1", "values": [data]}} + return self._feishu_request("PUT", 
f"/sheets/v2/spreadsheets/{spreadsheet_token}/values", json=payload) is not None + + def create_sheet_for_testcase(self, spreadsheet_token, sheet_title, template_sheet_id): + """Create a sheet from template given a specific title""" + payload = { + "requests": [ + { + "copySheet": { + "source": {"sheetId": template_sheet_id}, + "destination": {"title": sheet_title} + } + } + ] + } + resp = self._feishu_request("POST", f"/sheets/v2/spreadsheets/{spreadsheet_token}/sheets_batch_update", json=payload) + if resp: + try: + new_sheet_id = resp["data"]["replies"][0]["copySheet"]["properties"]["sheetId"] + return new_sheet_id + except Exception: + print("Unexpected copySheet response:", resp) + return None + else: + return None + + def sort_sheets_by_title(self, spreadsheet_token, template_title = "模板") -> bool: + sheets = self.get_all_sheet_ids(spreadsheet_token) + + template = None + normal = [] + for s in sheets: + if s["title"] == template_title: + template = s + else: + normal.append(s) + + def natural_key(s: str): + return [int(x) if x.isdigit() else x.lower() for x in re.split(r'(\d+)', s)] + + normal.sort(key=lambda x: natural_key(x["title"])) + ordered = normal + ([template] if template else []) + + requests_ = [] + for new_index, s in enumerate(ordered): + sheet_id = s["sheet_id"] + requests_.append({ + "updateSheet": { + "properties": { + "sheetId": sheet_id, + "index": new_index + } + } + }) + + if not requests_: + return True + + payload = {"requests": requests_} + return self._feishu_request("POST", f"/sheets/v2/spreadsheets/{spreadsheet_token}/sheets_batch_update", json=payload) is not None + + def post_process(self, spreadsheet_token, sheet_id): """Post-processing: set styles and merge cells""" row_count = self.get_sheet_row_count(spreadsheet_token, sheet_id) @@ -175,11 +251,38 @@ def load_config(config_file): return config +def parse_command_args(log_content: str, start_flag="--dtype"): + """Parse command-line arguments from [COMMAND] line""" + for line in log_content.splitlines(): + if line.startswith("[COMMAND]"): + idx = line.find(start_flag) + if idx != -1: + return line[idx:].strip() + return None + return None def parse_training_log(log_content): - """Parse training log to extract avg latency and throughput from step >= 2""" - pattern = r"step\s+(\d+)/\d+\s+\|.*?\|\s+\(\s*(\d+\.\d+)\s+ms\s+\|\s+(\d+)\s+tok/s.*?\)" - matches = re.findall(pattern, log_content) + """Parse training log to extract avg latency and throughput from step >= 2 and peak mem usage during whole time""" + pattern_with_peak = ( + r"step\s+(\d+)/\d+\s+\|.*?\|\s+\(\s*" + r"(\d+(?:\.\d+)?)\s*ms\s*\|\s*" + r"(\d+(?:\.\d+)?)\s*tok/s\s*\|\s*" + r"peak used:\s*(\d+)\s*MB\s*\|\s*" + r"peak reserved:\s*(\d+)\s*MB" + ) + + # NOTE(zbl): This is for compatibility reasons + pattern_no_peak = ( + r"step\s+(\d+)/\d+\s+\|.*?\|\s+\(\s*" + r"(\d+(?:\.\d+)?)\s*ms\s*\|\s*" + r"(\d+(?:\.\d+)?)\s*tok/s" + ) + + matches = re.findall(pattern_with_peak, log_content) + has_peak = True + if not matches: + matches = re.findall(pattern_no_peak, log_content) + has_peak = False filtered = [m for m in matches if int(m[0]) > 1] if not filtered: @@ -192,7 +295,15 @@ def parse_training_log(log_content): avg_latency = round(sum(latencies) / len(latencies), 2) avg_throughput = round(sum(throughputs) / len(throughputs), 2) - return [avg_latency, avg_throughput] + peak_used_max = None + peak_reserved_max = None + if has_peak: + peak_used = [int(m[3]) for m in filtered] + peak_reserved = [int(m[4]) for m in filtered] + peak_used_max = 
max(peak_used) + peak_reserved_max = max(peak_reserved) + + return [avg_latency, avg_throughput, peak_used_max, peak_reserved_max] def parse_profile_report(profile_content): @@ -265,6 +376,20 @@ def parse_profile_report(profile_content): return merged_df.head(5).iloc[:, :16] return None +def discover_testcases(model_name: str, log_dir="logs"): + """Get all test case ids from the local log dir""" + pattern = os.path.join(log_dir, f"{model_name}_*.log") + files = glob.glob(pattern) + testcases = [] + prefix = f"{model_name}_" + for path in files: + base = os.path.basename(path) + if not base.startswith(prefix) or base.endswith("_profile.log") or not base.endswith(".log"): + continue + testcase = base[len(prefix):-len(".log")] + if testcase: + testcases.append(testcase) + return sorted(set(testcases)) def get_git_branch(): """Get current git branch""" @@ -289,14 +414,16 @@ def get_model_data(model_name, sheet_title): log_file_path = f"logs/{model_name}_{sheet_title}.log" profile_file_path = f"profile_logs/{model_name}_{sheet_title}_profile_{model_name}.report.rank0" - avg_latency, avg_throughput = None, None + avg_latency, avg_throughput, peak_used_max, peak_reserved_max, cmd_args = None, None, None, None, None # Read training log if os.path.exists(log_file_path): with open(log_file_path, 'r', encoding='utf-8') as f: - result = parse_training_log(f.read()) + content = f.read() + result = parse_training_log(content) if result: - avg_latency, avg_throughput = result + avg_latency, avg_throughput, peak_used_max, peak_reserved_max = result + cmd_args = parse_command_args(content) else: print(f"Training log does not exist: {log_file_path}") @@ -311,12 +438,12 @@ if report_df is None: return [] - # Insert 5 empty columns at the front - new_data = [["" for _ in range(5)] for _ in range(5)] + # Insert $META_COLS empty columns at the front + new_data = [["" for _ in range(META_COLS)] for _ in range(5)] new_df = pd.DataFrame(new_data, index=report_df.index) combined_df = pd.concat([new_df, report_df], axis=1) - # Fill first row's first 5 columns with info + # Fill first row's first $META_COLS columns with info combined_df.iloc[0, 0] = FeishuSheetHandler.convert_to_feishu_date(datetime.now().date()) combined_df.iloc[0, 1] = get_git_branch() combined_df.iloc[0, 2] = get_git_commit_id() @@ -324,8 +451,12 @@ combined_df.iloc[0, 3] = avg_latency if avg_throughput is not None: combined_df.iloc[0, 4] = avg_throughput + if peak_used_max is not None: + combined_df.iloc[0, 5] = peak_used_max + if peak_reserved_max is not None: + combined_df.iloc[0, 6] = peak_reserved_max - return combined_df.values.tolist() + return cmd_args, combined_df.values.tolist() def main(): @@ -350,26 +481,54 @@ print(f"\n=== Start processing {model_name} ===") model_name = model_name.lower() - model_sheets = handler.get_all_sheet_ids(spreadsheet_token) - if not model_sheets: - print(f"No sheets retrieved for {model_name}, skipping") + testcases = discover_testcases(model_name) + if not testcases: + print(f"No local testcases found under logs/ for model={model_name}, skipping") continue + print(f"Discovered {len(testcases)} local testcases: {testcases}") - print(f"Found {len(model_sheets)} sheets in {model_name}'s spreadsheet") - - for sheet in model_sheets: - if sheet["title"] == "模板": - continue + remote_sheets = handler.get_all_sheet_ids(spreadsheet_token) + remote_by_title = {s["title"]: s["sheet_id"] for s in remote_sheets} - print(f"\nProcessing sheet
{sheet['index']}: {sheet['title']} (ID: {sheet['sheet_id']})") + if "模板" not in remote_by_title: + print(f"No template sheets retrieved for {model_name}, skipping") + continue + template_sheet_id = remote_by_title["模板"] + + sort_sheets = False + + for testcase in testcases: + print("\n-------") + sheet_id = remote_by_title.get(testcase) + write_cmd = False + + if not sheet_id: + print(f"Sheet for '{testcase}' not found, creating from template...") + sheet_id = handler.create_sheet_for_testcase(spreadsheet_token, sheet_title=testcase, template_sheet_id=template_sheet_id) + if not sheet_id: + print(f"Failed to create sheet '{testcase}', skipping") + continue + remote_by_title[testcase] = sheet_id + sort_sheets = True + write_cmd = True + print(f"Created sheet '{testcase}' with id={sheet_id}") + + print(f"Processing testcase '{testcase}' -> sheet_id={sheet_id}") + + cmd_args, sheet_data = get_model_data(model_name=model_name, sheet_title=testcase) - sheet_data = get_model_data(model_name=model_name, sheet_title=sheet['title']) if not sheet_data: print("No valid data generated, skipping") continue - if handler.prepend_data(spreadsheet_token, sheet["sheet_id"], sheet_data): - handler.post_process(spreadsheet_token, sheet["sheet_id"]) + if write_cmd and cmd_args: + handler.write_cmd_args_to_header(spreadsheet_token, cmd_args, sheet_id) + + if handler.prepend_data(spreadsheet_token, sheet_id, sheet_data): + handler.post_process(spreadsheet_token, sheet_id) + + if sort_sheets: + handler.sort_sheets_by_title(spreadsheet_token, "模板") print("\n=== All models and sheets processed ===")
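
Editor's notes on the changes above follow; none of the code below is part of the patch.

The padding rules in ParamAndGradBuffer::BuildBuckets are easier to check with concrete numbers. The standalone sketch below assumes kParamStartAlignElements = 64 and kBucketEndAlignElements = 128 (the values the comments in BuildBuckets describe) and reimplements PadTo locally; it is a sketch of the arithmetic only, not InfiniTrain code.

    #include <cstddef>
    #include <cstdio>
    #include <numeric>

    // Round value up to the next multiple of align (mirrors the PadTo used by BuildBuckets).
    static size_t PadTo(size_t value, size_t align) { return (value + align - 1) / align * align; }

    int main() {
        const size_t kParamStartAlignElements = 64;  // params start 128B-aligned for >=16-bit dtypes
        const size_t kBucketEndAlignElements = 128;  // buckets start 256B-aligned
        const size_t ddp_world_size = 8;
        const bool pad_buckets_for_high_nccl_busbw = true;

        size_t param_end_index = 1000;  // hypothetical last param of the current bucket

        size_t lcm_val = std::lcm(ddp_world_size, kBucketEndAlignElements);
        if (pad_buckets_for_high_nccl_busbw) {
            lcm_val = std::lcm(lcm_val, static_cast<size_t>(1u << 16));  // divisible by 2^16
        }
        const size_t bucket_end_index = PadTo(param_end_index, lcm_val);
        const size_t next_param_start = PadTo(bucket_end_index, kParamStartAlignElements);

        std::printf("bucket end: %zu -> %zu, next param starts at %zu\n", param_end_index,
                    bucket_end_index, next_param_start);
        return 0;
    }

With pad_buckets_for_high_nccl_busbw enabled, the effective alignment here is lcm(8, 128, 65536) = 65536 elements, so even a 1000-element bucket is padded out to 65536; that padding is the price paid for the higher NCCL bus bandwidth the comment describes.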
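
The reduce-scatter and all-gather paths in ParamAndGradBucketGroup::StartGradSync / StartParamSync rely on ShardBuffer cutting each padded flat buffer into collective_pg_size_ equal views and picking the local one with rank_in_collective_pg_. A minimal sketch of that slicing, using a hypothetical BufferView struct in place of real tensor views, under the assumption (which BuildBuckets enforces via CHECK_EQ(numel_ % ddp_world_size_, 0)) that the buffer length is already a multiple of the world size:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    struct BufferView {  // stand-in for a tensor view into the flat buffer
        size_t offset;
        size_t numel;
    };

    std::vector<BufferView> ShardBuffer(size_t numel, size_t world_size) {
        // One equal-sized shard per rank; numel is assumed to be divisible by world_size.
        std::vector<BufferView> shards;
        const size_t shard_numel = numel / world_size;
        for (size_t r = 0; r < world_size; ++r) {
            shards.push_back({r * shard_numel, shard_numel});
        }
        return shards;
    }

    int main() {
        const size_t numel = 65536;  // padded bucket size
        const size_t world_size = 8;
        const int rank_in_collective_pg = 3;

        auto shards = ShardBuffer(numel, world_size);
        const auto &local = shards[rank_in_collective_pg];
        std::printf("rank %d owns elements [%zu, %zu)\n", rank_in_collective_pg, local.offset,
                    local.offset + local.numel);
        return 0;
    }

Reduce-scatter writes each rank's reduced gradients only into its own view, and the later all-gather uses the same offsets to broadcast each rank's updated parameter shard back into the full buffer.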
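
For the end-to-end dataflow the bucket groups drive when use_distributed_optimizer is enabled, here is a toy single-process simulation with plain vectors (no ProcessGroup, no NCCL, made-up numbers): average gradients into per-rank shards (the ReduceOpType::kAvg branch), let each rank run SGD on its own parameter shard only, then all-gather the updated shards into every rank's full parameter buffer.

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
        const size_t world_size = 4;
        const size_t numel = 8;               // padded flat-buffer length, divisible by world_size
        const size_t shard = numel / world_size;
        const float lr = 0.1f;

        // One flat param buffer and one grad buffer per simulated rank.
        std::vector<std::vector<float>> params(world_size, std::vector<float>(numel, 1.0f));
        std::vector<std::vector<float>> grads(world_size, std::vector<float>(numel, 0.5f));

        // 1. Reduce-scatter (average): rank r ends up with the mean gradient of its shard only.
        std::vector<std::vector<float>> grad_shards(world_size, std::vector<float>(shard, 0.0f));
        for (size_t r = 0; r < world_size; ++r) {
            for (size_t i = 0; i < shard; ++i) {
                float sum = 0.0f;
                for (size_t src = 0; src < world_size; ++src) { sum += grads[src][r * shard + i]; }
                grad_shards[r][i] = sum / world_size;
            }
        }

        // 2. Each rank runs the inner optimizer on its own parameter shard.
        for (size_t r = 0; r < world_size; ++r) {
            for (size_t i = 0; i < shard; ++i) { params[r][r * shard + i] -= lr * grad_shards[r][i]; }
        }

        // 3. All-gather: every rank copies every owner's updated shard back into its full buffer.
        for (size_t r = 0; r < world_size; ++r) {
            for (size_t owner = 0; owner < world_size; ++owner) {
                for (size_t i = 0; i < shard; ++i) {
                    params[r][owner * shard + i] = params[owner][owner * shard + i];
                }
            }
        }

        std::printf("param[0] after step on every rank: %.3f (expected 1 - 0.1 * 0.5 = 0.950)\n",
                    params[0][0]);
        return 0;
    }

With average_in_collective disabled, the reduce would use kSum and step 1 would skip the division by world_size.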
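
Finally, the reason the remap loop in BuildBuckets calls SetData and set_main_grad: once every parameter's gradient is a view into the flat grad buffer, whatever backward writes through that view is already laid out contiguously for the bucket's collective call, with no per-parameter copy. A small stand-alone illustration with raw pointers standing in for the GetBufferView tensor views (hypothetical names, not the real API):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    struct View {  // minimal stand-in for a per-parameter main_grad view
        float *data;
        size_t numel;
    };

    int main() {
        // Two "parameters" of 3 and 5 elements packed back to back (no padding for brevity).
        std::vector<float> grad_buffer(8, 0.0f);
        View grad_of_w = {grad_buffer.data() + 0, 3};  // param_index_map_: w -> [0, 3)
        View grad_of_b = {grad_buffer.data() + 3, 5};  // param_index_map_: b -> [3, 8)

        // Backward writes through the per-parameter views...
        for (size_t i = 0; i < grad_of_w.numel; ++i) { grad_of_w.data[i] += 1.0f; }
        for (size_t i = 0; i < grad_of_b.numel; ++i) { grad_of_b.data[i] += 2.0f; }

        // ...and the flat buffer is already the exact payload for the bucket's collective call.
        std::printf("flat grad buffer:");
        for (float g : grad_buffer) { std::printf(" %.0f", g); }
        std::printf("\n");  // expected: 1 1 1 2 2 2 2 2
        return 0;
    }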