diff --git a/apps/nccl/src/allgather.cu b/apps/nccl/src/allgather.cu index eb40d68b9..60e38f2c4 100644 --- a/apps/nccl/src/allgather.cu +++ b/apps/nccl/src/allgather.cu @@ -19,6 +19,7 @@ AllgatherAlgo6::AllgatherAlgo6() : disableChannelCache_(false) { void AllgatherAlgo6::initialize(std::shared_ptr comm, std::unordered_map>&) { this->conns_ = setupConnections(comm); + this->memorySemaphores_ = std::move(setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_)); } ncclResult_t AllgatherAlgo6::allgatherKernelFunc(const std::shared_ptr ctx, const void* input, @@ -61,15 +62,13 @@ ncclResult_t AllgatherAlgo6::allgatherKernelFunc(const std::shared_ptr AllgatherAlgo6::initAllgatherContext(std::shared_ptr comm, const void*, void* output, size_t count, ncclDataType_t dtype) { - constexpr int nChannelsPerConnection = 35; - auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); ctx->workSize = comm->bootstrap()->getNranks(); ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); // setup semaphores - ctx->memorySemaphores = std::move(setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection)); + ctx->memorySemaphores = this->memorySemaphores_; size_t bytes = count * ncclTypeSize(dtype); size_t recvBytes; @@ -88,7 +87,7 @@ std::shared_ptr AllgatherAlgo6::initAllgatherContext(std: comm->registerMemory((void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc); std::vector remoteMemories = setupRemoteMemories(comm, ctx->rank, localMemory); ctx->memoryChannels = std::move( - setupMemoryChannels(this->conns_, ctx->memorySemaphores, remoteMemories, localMemory, nChannelsPerConnection)); + setupMemoryChannels(this->conns_, ctx->memorySemaphores, remoteMemories, localMemory, nChannelsPerConnection_)); ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels); // keep registered memories reference diff --git a/apps/nccl/src/allgather.hpp b/apps/nccl/src/allgather.hpp index 174261abe..8dc06addf 100644 --- a/apps/nccl/src/allgather.hpp +++ b/apps/nccl/src/allgather.hpp @@ -217,6 +217,8 @@ class AllgatherAlgo6 : public mscclpp::AlgorithmBuilder { private: bool disableChannelCache_; std::vector> conns_; + std::vector> memorySemaphores_; + const int nChannelsPerConnection_ = 35; void initialize(std::shared_ptr comm, std::unordered_map>&); ncclResult_t allgatherKernelFunc(const std::shared_ptr ctx, const void* input, void* output, diff --git a/apps/nccl/src/allreduce.cu b/apps/nccl/src/allreduce.cu index e255d8c7a..0a48db61b 100644 --- a/apps/nccl/src/allreduce.cu +++ b/apps/nccl/src/allreduce.cu @@ -460,7 +460,8 @@ void Allreduce8::initialize(std::shared_ptr comm, nChannelsPerConnection_ = 64; comm_ = comm; // setup semaphores - this->deviceSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_); + this->outputSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_); + this->inputScratchSemaphores_ = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_); mscclpp::RegisteredMemory localMemory = comm->registerMemory(scratchBuffer_.get(), scratchBufferSize_, mscclpp::Transport::CudaIpc); this->remoteScratchMemories_ = setupRemoteMemories(comm, comm->bootstrap()->getRank(), localMemory); @@ -483,8 +484,9 @@ ncclResult_t Allreduce8::allreduceKernelFunc(const std::shared_ptrregisterMemory(const_cast(input), bytes, mscclpp::Transport::CudaIpc); - std::vector channels = setupMemoryChannels( - this->conns_, this->deviceSemaphores_, this->remoteScratchMemories_, localMemory, nChannelsPerConnection_); + std::vector channels = + setupMemoryChannels(this->conns_, this->inputScratchSemaphores_, this->remoteScratchMemories_, localMemory, + nChannelsPerConnection_); this->memoryChannelsMap_[input] = std::make_pair(channels, setupMemoryChannelDeviceHandles(channels)); } inputChannelHandles = this->memoryChannelsMap_[input].second; @@ -525,7 +527,7 @@ std::shared_ptr Allreduce8::initAllreduceContext(std::sha ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); // setup semaphores - ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection_); + ctx->memorySemaphores = this->outputSemaphores_; // setup memories and channels size_t recvBytes; CUdeviceptr recvBasePtr; diff --git a/apps/nccl/src/allreduce.hpp b/apps/nccl/src/allreduce.hpp index 3078d2647..e2315146f 100644 --- a/apps/nccl/src/allreduce.hpp +++ b/apps/nccl/src/allreduce.hpp @@ -887,7 +887,8 @@ class Allreduce8 : public mscclpp::AlgorithmBuilder { int nChannelsPerConnection_; std::vector> conns_; std::shared_ptr scratchBuffer_; - std::vector> deviceSemaphores_; + std::vector> outputSemaphores_; + std::vector> inputScratchSemaphores_; std::vector remoteScratchMemories_; mscclpp::RegisteredMemory localScratchMemory_; std::unordered_map, diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index b34071c22..71f1449f9 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -875,10 +875,7 @@ NCCL_API ncclResult_t ncclAllToAllv(const void* sendbuff, [[maybe_unused]] const } NCCL_API ncclResult_t ncclGroupStart() { - if (!tryLoadNcclSharedLib()) { - WARN("Failed to load the shared library for nccl/rccl"); - return ncclInternalError; - } + tryLoadNcclSharedLib(); if (mscclppNcclDlopenSharedLib == true) { return mscclppNcclOps.GroupStart(); }