7 changes: 6 additions & 1 deletion infini_train/include/nn/parallel/distributed_data_parallel.h
@@ -3,6 +3,7 @@
#include <memory>

#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/reducer.h"

namespace infini_train {
class Tensor;
@@ -13,9 +14,13 @@ namespace infini_train::nn::parallel {

class DistributedDataParallel : public nn::Module {
public:
DistributedDataParallel(std::shared_ptr<nn::Module> module, int device_id);
DistributedDataParallel(std::shared_ptr<nn::Module> module, int device_id,
const ReducerOptions &opts = ReducerOptions{});

std::vector<std::shared_ptr<Tensor>> Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) override;

private:
std::shared_ptr<Reducer> reducer_;
};

} // namespace infini_train::nn::parallel
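
A minimal construction sketch for the new overload (model and device_id are placeholders; only DistributedDataParallel and the ReducerOptions fields declared in reducer.h come from this PR):

// Sketch: wrap an existing module with custom bucketing options; defaults apply when opts is omitted.
ReducerOptions opts;
opts.first_bucket_cap_mb = 32;         // a smaller first bucket can let communication start earlier
opts.normal_bucket_cap_mb = 256;
opts.gradient_as_bucket_view = true;   // grads alias the flat bucket buffer instead of being copied
auto ddp = std::make_shared<DistributedDataParallel>(model, device_id, opts);
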
9 changes: 9 additions & 0 deletions infini_train/include/nn/parallel/process_group.h
@@ -11,6 +11,7 @@
#endif

#include "infini_train/include/nn/parallel/reduce_op_type.h"
#include "infini_train/include/nn/parallel/work.h"

namespace infini_train {
class Tensor;
@@ -28,8 +29,11 @@ class ProcessGroup {
public:
explicit ProcessGroup(const std::vector<int> &device_indices);

~ProcessGroup();

int GetGroupRank(int thread_rank) const;

// Communication operations
void AllReduce(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const;

void AllGather(const std::shared_ptr<Tensor> &output, const std::shared_ptr<Tensor> &input) const;
@@ -52,11 +56,16 @@

std::vector<std::shared_ptr<Tensor>> NcclRecv(std::vector<std::shared_ptr<Tensor>> tensors, int src_rank) const;

// Async communication functions
std::shared_ptr<Work> AllReduceAsync(const std::shared_ptr<Tensor> &tensor, function::ReduceOpType reduce_op) const;

private:
std::vector<ncclComm_t> comms_;
std::vector<cudaStream_t> comm_streams_;
std::vector<const Device *> devices_;

std::unordered_map<const Device *, ncclComm_t> device_comm_map_;
std::unordered_map<const Device *, cudaStream_t> device_stream_map_;
std::unordered_map<int, int> thread_group_rank_map_; // thread_rank : group_rank

int comm_size_ = 0;
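
A hedged usage sketch of the new asynchronous path (pg and grad_tensor are placeholders; AllReduceAsync, Work::Wait and Work::exception are the calls declared in this PR):

// Launch a non-blocking all-reduce, overlap other work, then wait with a timeout.
std::shared_ptr<Work> work = pg->AllReduceAsync(grad_tensor, function::ReduceOpType::kAvg);
// ... other host-side work can run here ...
if (!work->Wait(std::chrono::milliseconds(5000))) {
    if (auto eptr = work->exception()) {
        std::rethrow_exception(eptr);   // surface the failure recorded by the backend
    }
}
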
149 changes: 149 additions & 0 deletions infini_train/include/nn/parallel/reducer.h
@@ -0,0 +1,149 @@
#pragma once

#include <atomic>
#include <cstdint>
#include <memory>
#include <mutex>
#include <vector>

#include "infini_train/include/autograd/function_hook.h"
#include "infini_train/include/nn/parallel/parallel_functional.h"
#include "infini_train/include/tensor.h"

namespace infini_train::nn::parallel {

// GradBucket passes a bucket's contents tensors to the DDP communication hook.
// ref: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/comm.hpp
class GradBucket {
public:
explicit GradBucket(const std::vector<std::shared_ptr<Tensor>> &tensors) : tensors_(tensors) {}
const std::vector<std::shared_ptr<Tensor>> &getTensors() const { return tensors_; }

private:
std::vector<std::shared_ptr<Tensor>> tensors_;
};

// Compute the bucket assignment from each tensor's size and the given bucket capacities.
// Returns the indices of the tensors in each bucket, i.e. output[bucket_i] = {tensor_j, tensor_k, ...}.
// The index recorded for tensors[idx] (j and k above) is tensor_indices[idx];
// when tensor_indices is empty, the index recorded for tensors[idx] is idx itself.
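// Illustrative example (assumed sizes, not from a real run): for three dense tensors of roughly
// 1 MB, 100 MB and 400 MB with bucket_size_limits = {128 MB, 512 MB}, the first two tensors share
// bucket 0 and the third starts bucket 1:
//   ComputeBucketAssignmentBySize({t0, t1, t2}, {128 << 20, 512 << 20})  ->  {{0, 1}, {2}}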
std::vector<std::vector<size_t>> ComputeBucketAssignmentBySize(const std::vector<std::shared_ptr<Tensor>> &tensors,
const std::vector<size_t> &bucket_size_limits,
const std::vector<size_t> &tensor_indices = {});

struct ReducerOptions {
// Pack all Reducer-related args together
// Ref: https://docs.pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html

// Maximum capacity of each bucket (in MB)
size_t first_bucket_cap_mb = 128;
size_t normal_bucket_cap_mb = 512;

// When true, map param.grad directly onto its slice of the bucket's flat buffer (same memory address) instead of copying via memcpy
bool gradient_as_bucket_view = true;
};

// DDP Reducer that handles gradient bucketing in backward
// ref: https://github.com/pytorch/pytorch/blob/main/torch/csrc/distributed/c10d/reducer.hpp
class Reducer : public std::enable_shared_from_this<Reducer> {
public:
/** @brief Constructor of Reducer
*
* @param parameters A list of parameters for this process's single model replica
* @param bucket_indices The bucket assignment for this reducer
* @param opts Other options, see definition of ReducerOptions
*/
explicit Reducer(std::vector<std::shared_ptr<Tensor>> parameters, std::vector<std::vector<size_t>> bucket_indices,
const ReducerOptions &opts);

// Attach post-accumulate-grad hooks (which trigger the all-reduce) to the params
void AttachHooksToParameters();

// Prepare bucket info for next step
void PrepareForBackward();

// Allows a custom DDP communication hook to override the default AllReduce.
// This can be used for algorithms like Gradient Compression/GossipGrad.
// Hook is registered using `Reducer::RegisterCommHook()`.
// TODO(zbl): Leave the placeholder for the moment
void RegisterCommHook(std::shared_ptr<autograd::PostAccumulateGradHook> hook);

// Return, for each bucket, the tensors that make up its flat buffer
std::vector<std::vector<std::shared_ptr<Tensor>>> GetBucketTensors() const;

private:
// A variable locator locates a particular variable in the reducer's buckets
struct VariableLocator {
// Index of the bucket containing the variable in the `buckets_` vector
size_t bucket_index = 0;
// Index of the variable in the bucket
size_t intra_bucket_index = 0;
};

// Bucket used in DDP backward
struct Bucket {
// Gradients of the bucket flattened into a 1-dimensional tensor
std::shared_ptr<Tensor> contents;
DataType dtype;
int device_rank = 0;

// Variables whose gradients are held in this bucket
std::vector<std::shared_ptr<Tensor>> variables;

// Per-variable offset/length into the flattened `contents` tensor,
// used when building the `GradBucket` passed to communication hooks.
// Expressed in element counts, not bytes.
std::vector<size_t> offsets;
std::vector<size_t> lengths;
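// Illustrative layout (assumed, for two variables of 3 and 5 elements):
//   contents = [g0_0 g0_1 g0_2 | g1_0 ... g1_4], offsets = {0, 3}, lengths = {3, 5}
//   bucket_views_in[i] views contents at [offsets[i], offsets[i] + lengths[i])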

// Views into the `contents` tensor for each individual gradient
std::vector<std::shared_ptr<Tensor>> bucket_views_in;
// TODO(zbl): reserved for occasions where grads have different stride/layout
std::vector<std::shared_ptr<Tensor>> bucket_views_out;

// Number of gradients left to be computed before the bucket is ready to be reduced
size_t pending;

// Global indices of participating variables in the bucket
std::vector<size_t> variable_indices;

// If this bucket should expect a single sparse gradient
// If `true`, then this implies that `bucket.variables.size() == 1`.
// TODO(zbl): support logics for sparse gradient later
bool expect_sparse_gradient = false;
};

private:
void InitializeBuckets(const std::vector<std::vector<size_t>> &bucket_indices);

// NOTE(zbl): all grads are assumed dense and stored contiguously in the bucket for now
void MarkVariableReadyDense(size_t variable_index);
void MarkBucketReady(size_t bucket_index);
void FinalizeBucketDense(size_t bucket_index);

void BuildBuckets(const std::vector<std::vector<size_t>> &bucket_indices);
void InitializeBucketViews(Bucket &bucket);
void RebuildBuckets();

private:
mutable std::mutex mutex_;
std::vector<std::shared_ptr<Tensor>> params_;
std::vector<Bucket> buckets_;
std::vector<VariableLocator> locators_;

std::atomic<size_t> buckets_finished_{0};
std::shared_ptr<autograd::PostAccumulateGradHook> comm_hook_ = nullptr;
ReducerOptions opts_;

// Next bucket to be reduced
// Ensures that bucket all-reduces are launched in the expected order
size_t next_bucket_ = 0;
// Records the order in which params become ready during the first step
std::vector<size_t> grad_ready_order_indices_;
// Records whether each param has already been marked ready this iteration
std::vector<uint8_t> ready_seen_this_iter_;
// Whether to rebuild buckets on next train step
bool need_rebuild_ = false;
// Whether buckets have already been rebuilt (this happens once, on the second step)
bool has_rebuilt_bucket_ = false;
};

} // namespace infini_train::nn::parallel
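
A sketch of how DistributedDataParallel wires up these declarations (module and opts are placeholders; the sequence mirrors the constructor and Forward changes further down in this diff):

// Build buckets once, attach hooks, and reset per-step state before each backward pass.
auto params = module->Parameters();
std::vector<size_t> limits = {opts.first_bucket_cap_mb * 1024ULL * 1024ULL,
                              opts.normal_bucket_cap_mb * 1024ULL * 1024ULL};
auto bucket_indices = ComputeBucketAssignmentBySize(params, limits);

auto reducer = std::make_shared<Reducer>(params, bucket_indices, opts);
reducer->AttachHooksToParameters();   // post-accumulate hooks mark variables ready during backward

// Called before each backward pass (DistributedDataParallel::Forward does this):
reducer->PrepareForBackward();        // reset per-step bucket bookkeeping
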
72 changes: 72 additions & 0 deletions infini_train/include/nn/parallel/work.h
@@ -0,0 +1,72 @@
#pragma once

#include <atomic>
#include <chrono>
#include <exception>
#include <memory>
#include <mutex>

#ifdef USE_CUDA
#include <cuda_runtime.h>
#endif
#ifdef USE_NCCL
#include <nccl.h>
#endif

#include "infini_train/include/device.h"

namespace infini_train::nn::parallel {

class Work {
public:
virtual ~Work() = default;

virtual bool Wait(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) = 0;

virtual bool IsCompleted() const = 0;
virtual bool IsSuccess() const = 0;

virtual void Synchronize() const = 0;

virtual std::exception_ptr exception() const = 0;

virtual void *ready_event() const = 0;
virtual void *done_event() const = 0;
};

#ifdef USE_NCCL
class WorkNccl final : public Work {
public:
WorkNccl(const Device *device, ncclComm_t comm);
~WorkNccl() override;

bool Wait(std::chrono::milliseconds timeout = std::chrono::milliseconds::zero()) override;

bool IsCompleted() const override;
bool IsSuccess() const override;

void Synchronize() const override;

std::exception_ptr exception() const override { return exception_; }

void *ready_event() const override { return reinterpret_cast<void *>(ready_event_); }
void *done_event() const override { return reinterpret_cast<void *>(done_event_); }

private:
bool CheckNcclStatus();
void SetException(std::exception_ptr e);

private:
const Device *device_ = nullptr;
cudaEvent_t ready_event_;
cudaEvent_t done_event_;
ncclComm_t comm_;

mutable std::mutex mutex_;
std::exception_ptr exception_;
std::atomic<bool> completed_{false};
std::atomic<bool> success_{false};
};
#endif

} // namespace infini_train::nn::parallel
7 changes: 7 additions & 0 deletions infini_train/include/tensor.h
@@ -82,6 +82,9 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
Tensor To(const Device *device);
Tensor To(DataType dtype);

void CopyFrom(const Tensor &src);
void CopyFrom(const std::shared_ptr<Tensor> &src);

// operator overloading
std::shared_ptr<Tensor> Equals(const std::shared_ptr<Tensor> &other);
std::shared_ptr<Tensor> Equals(float scalar);
@@ -208,6 +211,8 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
void set_output_idx(int output_idx);

void ZeroGrad(bool set_to_none = true);
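// One-shot flag used by the DDP Reducer: when set, the next gradient accumulation overwrites
// .grad in place via CopyFrom() instead of adding to it, and the flag is then cleared
// (see AccumulateGrad::Backward and Reducer::PrepareForBackward).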
void MarkGradOverwriteOnNextAccum();
bool ConsumeGradOverwriteFlag();

void Backward(std::shared_ptr<Tensor> gradient = nullptr, bool retain_graph = false,
bool create_graph = false) const;
@@ -229,6 +234,8 @@
// a strong reference to the accumulator to manage its lifetime.
std::shared_ptr<autograd::AccumulateGrad> grad_accumulator_ = nullptr;
std::shared_ptr<autograd::PostAccumulateGradHook> post_accumulate_grad_hook_ = nullptr;

bool grad_overwrite_once_ = false;
};

std::shared_ptr<Tensor> operator==(const std::shared_ptr<Tensor> &t, float scalar);
12 changes: 10 additions & 2 deletions infini_train/src/autograd/accumulate.cc
@@ -26,9 +26,17 @@ AccumulateGrad::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_output

if (grad_output) {
if (grad) {
auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AccumulateGrad"});
kernel.Call<void>(grad_output, learning_rate_, grad);
if (tensor_->ConsumeGradOverwriteFlag()) {
// If the tensor is marked to overwrite its current grad on the next grad update
// See notes in `infini_train::nn::parallel::Reducer::PrepareForBackward()`
// NOTE(zbl): must copy, cannot change grad buffer address
grad->CopyFrom(grad_output);
} else {
auto kernel = Dispatcher::Instance().GetKernel({device->Type(), "AccumulateGrad"});
kernel.Call<void>(grad_output, learning_rate_, grad);
}
} else {
// NOTE(zbl): check whether copying is needed here instead of slicing
auto new_grad = std::make_shared<Tensor>(*grad_output.get(), 0, grad_output->Dims());
tensor_->set_grad(std::move(new_grad));
}
25 changes: 16 additions & 9 deletions infini_train/src/nn/parallel/distributed_data_parallel.cc
@@ -17,25 +17,32 @@ namespace {
constexpr char kModuleName[] = "module";
} // namespace

DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> module, int device_id) {
DistributedDataParallel::DistributedDataParallel(std::shared_ptr<nn::Module> module, int device_id,
const ReducerOptions &opts) {
for (auto &param : module->Parameters()) {
auto device = param->GetDevice();
CHECK_EQ(device->Index(), device_id) << "All parameters must be on the same device as the module";

auto ddp_pg
= ProcessGroupFactory::Instance()->Get(GetDataParallelProcessGroupName(device->rank().thread_rank()));
auto hook = std::make_unique<infini_train::autograd::AllReducePostAccumulateHook>(function::ReduceOpType::kAvg,
ddp_pg);
param->RegisterPostAccumulateGradHook(std::move(hook));
CHECK_EQ(param->GetDevice()->Index(), device_id) << "All parameters must be on the same device as the module";
}
for (auto &buffer : module->Buffers()) {
CHECK_EQ(buffer->GetDevice()->Index(), device_id) << "All buffers must be on the same device as the module";
}
modules_[kModuleName] = std::move(module);

// Bucket Assignment
auto params = modules_[kModuleName]->Parameters();
const size_t first_cap_bytes = opts.first_bucket_cap_mb * 1024ULL * 1024ULL;
const size_t normal_cap_bytes = opts.normal_bucket_cap_mb * 1024ULL * 1024ULL;
std::vector<size_t> bucket_size_limits = {first_cap_bytes, normal_cap_bytes};
auto bucket_indices = ComputeBucketAssignmentBySize(params, bucket_size_limits);

reducer_ = std::make_shared<Reducer>(params, bucket_indices, opts);
reducer_->AttachHooksToParameters();
}

std::vector<std::shared_ptr<Tensor>>
DistributedDataParallel::Forward(const std::vector<std::shared_ptr<Tensor>> &input_tensors) {
if (reducer_) {
reducer_->PrepareForBackward();
}
return modules_[kModuleName]->Forward(input_tensors);
}
