Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions apps/nccl/src/allgather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ AllgatherAlgo6::AllgatherAlgo6() : disableChannelCache_(false) {
void AllgatherAlgo6::initialize(std::shared_ptr<mscclpp::Communicator> comm,
std::unordered_map<std::string, std::shared_ptr<void>>&) {
this->conns_ = setupConnections(comm);
this->memorySemaphores_ = std::move(setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_));
}

ncclResult_t AllgatherAlgo6::allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input,
Expand Down Expand Up @@ -61,15 +62,13 @@ ncclResult_t AllgatherAlgo6::allgatherKernelFunc(const std::shared_ptr<mscclpp::
std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo6::initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm,
const void*, void* output, size_t count,
ncclDataType_t dtype) {
constexpr int nChannelsPerConnection = 35;

auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
ctx->rank = comm->bootstrap()->getRank();
ctx->workSize = comm->bootstrap()->getNranks();
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();

// setup semaphores
ctx->memorySemaphores = std::move(setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection));
ctx->memorySemaphores = this->memorySemaphores_;

size_t bytes = count * ncclTypeSize(dtype);
size_t recvBytes;
Expand All @@ -88,7 +87,7 @@ std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo6::initAllgatherContext(std:
comm->registerMemory((void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(comm, ctx->rank, localMemory);
ctx->memoryChannels = std::move(
setupMemoryChannels(this->conns_, ctx->memorySemaphores, remoteMemories, localMemory, nChannelsPerConnection));
setupMemoryChannels(this->conns_, ctx->memorySemaphores, remoteMemories, localMemory, nChannelsPerConnection_));
ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);

// keep registered memories reference
Expand Down
2 changes: 2 additions & 0 deletions apps/nccl/src/allgather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ class AllgatherAlgo6 : public mscclpp::AlgorithmBuilder {
private:
bool disableChannelCache_;
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> memorySemaphores_;
const int nChannelsPerConnection_ = 35;

void initialize(std::shared_ptr<mscclpp::Communicator> comm, std::unordered_map<std::string, std::shared_ptr<void>>&);
ncclResult_t allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output,
Expand Down
10 changes: 6 additions & 4 deletions apps/nccl/src/allreduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,8 @@ void Allreduce8::initialize(std::shared_ptr<mscclpp::Communicator> comm,
nChannelsPerConnection_ = 64;
comm_ = comm;
// setup semaphores
this->deviceSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);
this->outputSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);
this->inputScratchSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);
mscclpp::RegisteredMemory localMemory =
comm->registerMemory(scratchBuffer_.get(), scratchBufferSize_, mscclpp::Transport::CudaIpc);
this->remoteScratchMemories_ = setupRemoteMemories(comm, comm->bootstrap()->getRank(), localMemory);
Expand All @@ -483,8 +484,9 @@ ncclResult_t Allreduce8::allreduceKernelFunc(const std::shared_ptr<mscclpp::Algo
} else {
mscclpp::RegisteredMemory localMemory =
comm_->registerMemory(const_cast<void*>(input), bytes, mscclpp::Transport::CudaIpc);
std::vector<mscclpp::MemoryChannel> channels = setupMemoryChannels(
this->conns_, this->deviceSemaphores_, this->remoteScratchMemories_, localMemory, nChannelsPerConnection_);
std::vector<mscclpp::MemoryChannel> channels =
setupMemoryChannels(this->conns_, this->inputScratchSemaphores_, this->remoteScratchMemories_, localMemory,
nChannelsPerConnection_);
this->memoryChannelsMap_[input] = std::make_pair(channels, setupMemoryChannelDeviceHandles(channels));
}
inputChannelHandles = this->memoryChannelsMap_[input].second;
Expand Down Expand Up @@ -525,7 +527,7 @@ std::shared_ptr<mscclpp::AlgorithmCtx> Allreduce8::initAllreduceContext(std::sha
ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();

// setup semaphores
ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_);
ctx->memorySemaphores = this->outputSemaphores_;
// setup memories and channels
size_t recvBytes;
CUdeviceptr recvBasePtr;
Expand Down
3 changes: 2 additions & 1 deletion apps/nccl/src/allreduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,8 @@ class Allreduce8 : public mscclpp::AlgorithmBuilder {
int nChannelsPerConnection_;
std::vector<std::shared_ptr<mscclpp::Connection>> conns_;
std::shared_ptr<char> scratchBuffer_;
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> deviceSemaphores_;
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> outputSemaphores_;
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>> inputScratchSemaphores_;
std::vector<mscclpp::RegisteredMemory> remoteScratchMemories_;
mscclpp::RegisteredMemory localScratchMemory_;
std::unordered_map<const void*, std::pair<std::vector<mscclpp::MemoryChannel>,
Expand Down
5 changes: 1 addition & 4 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -875,10 +875,7 @@ NCCL_API ncclResult_t ncclAllToAllv(const void* sendbuff, [[maybe_unused]] const
}

NCCL_API ncclResult_t ncclGroupStart() {
if (!tryLoadNcclSharedLib()) {
WARN("Failed to load the shared library for nccl/rccl");
return ncclInternalError;
}
tryLoadNcclSharedLib();
if (mscclppNcclDlopenSharedLib == true) {
return mscclppNcclOps.GroupStart();
}
Expand Down
Loading