Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 76 additions & 9 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ static inline void mscclppNcclDlopenFinalize() {
}

static inline int mscclppNcclInFallbackList(const char* collOps, const char* fallbackList) {
if (fallbackList == nullptr || fallbackList[0] == '\0' || strcmp(fallbackList, "all") == 0) {
if (strcmp(fallbackList, "all") == 0) {
return 1;
}

Expand Down Expand Up @@ -207,6 +207,7 @@ struct ncclComm {
uint32_t buffFlag;

int nRanksPerNode;
int worldSize;

std::shared_ptr<uint32_t> deviceFlag7;
std::shared_ptr<uint32_t> deviceFlag28;
Expand Down Expand Up @@ -703,10 +704,15 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
commPtr->comm = mscclppComm;
commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm);
commPtr->nRanksPerNode = mscclppComm->bootstrap()->getNranksPerNode();
commPtr->worldSize = mscclppComm->bootstrap()->getNranks();

if (commPtr->worldSize == 1) {
*comm = commPtr;
return ncclSuccess;
}

// FallBack for single node
if (mscclppComm->bootstrap()->getNranks() == mscclppComm->bootstrap()->getNranksPerNode())
ncclCommInitRankFallbackSingleNode(commPtr, mscclppComm, rank);
if (commPtr->worldSize == commPtr->nRanksPerNode) ncclCommInitRankFallbackSingleNode(commPtr, mscclppComm, rank);

const std::string& collectiveDir = mscclpp::env()->executionPlanDir;
if (collectiveDir != "") {
Expand Down Expand Up @@ -759,7 +765,12 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
return ncclSuccess;
}

NCCL_API ncclResult_t ncclCommInitAll(ncclComm_t*, int, const int*) {
NCCL_API ncclResult_t ncclCommInitAll(ncclComm_t* comm, int ndev, const int*) {
if (ndev == 1) {
ncclUniqueId Id;
ncclGetUniqueId(&Id);
return ncclCommInitRank(comm, ndev, Id, 0);
}
// TODO: implement this function
WARN("ncclCommInitAll is currently unavailable");
return ncclInternalError;
Expand Down Expand Up @@ -987,6 +998,14 @@ NCCL_API ncclResult_t ncclBroadcastFallback(const void* sendbuff, void* recvbuff

NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
int root, ncclComm_t comm, cudaStream_t stream) {
if (comm->worldSize == 1) {
if (sendbuff != recvbuff) {
size_t bytes = count * ncclTypeSize(datatype);
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
}
return ncclSuccess;
}

size_t bytes = count * ncclTypeSize(datatype);
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) {
WARN(
Expand All @@ -996,7 +1015,7 @@ NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t
}

int rank = comm->comm->bootstrap()->getRank();
INFO(MSCCLPP_INIT, "rank %d broadcast sendbuff %p recvbuff %p count %ld, dtype %d, comm: %p", rank, sendbuff,
INFO(MSCCLPP_NCCL, "rank %d broadcast sendbuff %p recvbuff %p count %ld, dtype %d, comm: %p", rank, sendbuff,
recvbuff, count, datatype, comm);

const char* fallbackList = mscclpp::env()->forceNcclFallbackOperation.c_str();
Expand Down Expand Up @@ -1047,6 +1066,13 @@ NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t

NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
ncclRedOp_t reductionOperation, ncclComm_t comm, cudaStream_t stream) {
if (comm->worldSize == 1) {
if (sendbuff != recvbuff) {
size_t bytes = count * ncclTypeSize(datatype);
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
}
return ncclSuccess;
}
// Checking if the parameters are valids
if (sendbuff == nullptr || recvbuff == nullptr || count == 0 || ncclTypeSize(datatype) == 0 || comm == nullptr) {
WARN(
Expand Down Expand Up @@ -1076,8 +1102,17 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
}
}

if (plan == nullptr)
int nRanks = comm->comm->bootstrap()->getNranks();
int nRanksPerNode = comm->comm->bootstrap()->getNranksPerNode();
if (plan == nullptr && nRanks == nRanksPerNode)
return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
if (plan == nullptr && mscclppNcclDlopenSharedLib) {
return mscclppNcclOps.AllReduce(sendbuff, recvbuff, count, datatype, reductionOperation,
*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm), stream);
} else if (plan == nullptr) {
WARN("No FallBack code for AllReduce when multi-node");
return ncclInternalError;
}

switch (datatype) {
case ncclFloat16:
Expand Down Expand Up @@ -1107,6 +1142,14 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t

NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype,
ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
if (comm->worldSize == 1) {
if (sendbuff != recvbuff) {
size_t bytes = recvcount * ncclTypeSize(datatype);
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
}
return ncclSuccess;
}

size_t bytes = recvcount * ncclTypeSize(datatype);
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) {
WARN(
Expand Down Expand Up @@ -1169,6 +1212,13 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, si

NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype,
ncclComm_t comm, cudaStream_t stream) {
if (comm->worldSize == 1) {
if (sendbuff != recvbuff) {
size_t bytes = sendcount * ncclTypeSize(datatype);
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
}
return ncclSuccess;
}
size_t bytes = sendcount * ncclTypeSize(datatype);
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) {
WARN(
Expand Down Expand Up @@ -1239,14 +1289,31 @@ NCCL_API ncclResult_t ncclRecv(void*, size_t, ncclDataType_t, int, ncclComm_t, c
return ncclInternalError;
}

NCCL_API ncclResult_t ncclAllToAll(const void*, void*, size_t, ncclDataType_t, ncclComm_t, cudaStream_t) {
NCCL_API ncclResult_t ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
ncclComm_t comm, cudaStream_t stream) {
if (comm->worldSize == 1) {
if (sendbuff != recvbuff) {
size_t bytes = count * ncclTypeSize(datatype);
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
}
return ncclSuccess;
}
// TODO: implement this function
WARN("ncclAllToAll is currently unavailable");
return ncclInternalError;
}

NCCL_API ncclResult_t ncclAllToAllv(const void*, const size_t[], const size_t[], void*, const size_t[], const size_t[],
ncclDataType_t, ncclComm_t, cudaStream_t) {
NCCL_API ncclResult_t ncclAllToAllv(const void* sendbuff, [[maybe_unused]] const size_t sendcounts[],
const size_t sdispls[], void* recvbuff, const size_t recvcounts[],
const size_t rdispls[], ncclDataType_t datatype, ncclComm_t comm,
cudaStream_t stream) {
if (comm->worldSize == 1) {
size_t bytes = recvcounts[0] * ncclTypeSize(datatype);
MSCCLPP_CUDATHROW(cudaMemcpyAsync((char*)recvbuff + rdispls[0] * ncclTypeSize(datatype),
(const char*)sendbuff + sdispls[0] * ncclTypeSize(datatype), bytes,
cudaMemcpyDeviceToDevice, stream));
return ncclSuccess;
}
WARN("ncclAllToAllv is currently unavailable");
return ncclInternalError;
}
Expand Down
13 changes: 12 additions & 1 deletion src/bootstrap/bootstrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "api.h"
#include "debug.h"
#include "socket.h"
#include "unix_socket.hpp"
#include "utils_internal.hpp"

namespace mscclpp {
Expand Down Expand Up @@ -114,6 +115,8 @@ class TcpBootstrap::Impl {
std::shared_ptr<Socket> getPeerSendSocket(int peer, int tag);
std::shared_ptr<Socket> getPeerRecvSocket(int peer, int tag);

UnixSocketServer& unixSocketServer_;

static void assignPortToUniqueId(UniqueIdInternal& uniqueId);
static void netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr);

Expand Down Expand Up @@ -149,7 +152,8 @@ TcpBootstrap::Impl::Impl(int rank, int nRanks)
peerCommAddresses_(nRanks, SocketAddress()),
barrierArr_(nRanks, 0),
abortFlagStorage_(new uint32_t(0)),
abortFlag_(abortFlagStorage_.get()) {}
abortFlag_(abortFlagStorage_.get()),
unixSocketServer_(UnixSocketServer::instance()) {}

UniqueId TcpBootstrap::Impl::getUniqueId() const { return getUniqueId(uniqueId_); }

Expand All @@ -172,6 +176,9 @@ void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId, int64_t timeoutSec
SocketToString(&uniqueId_.addr, line);
INFO(MSCCLPP_INIT, "rank %d nranks %d - connecting to %s", rank_, nRanks_, line);
establishConnections(timeoutSec);

unixSocketServer_.start();
INFO(MSCCLPP_INIT, "rank %d - unix socket server started", rank_);
}

void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t timeoutSec) {
Expand Down Expand Up @@ -204,6 +211,8 @@ void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t tim
}

establishConnections(timeoutSec);
unixSocketServer_.start();
INFO(MSCCLPP_INIT, "rank %d - unix socket server started", rank_);
}

TcpBootstrap::Impl::~Impl() {
Expand Down Expand Up @@ -567,6 +576,8 @@ void TcpBootstrap::Impl::close() {
ringSendSocket_.reset(nullptr);
peerSendSockets_.clear();
peerRecvSockets_.clear();
unixSocketServer_.stop();
UnixSocketClient::instance().reset();
}

MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() { return Impl::createUniqueId(); }
Expand Down
4 changes: 2 additions & 2 deletions src/include/registered_memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ struct TransportInfo {
char shareableHandle[64];
struct {
// These are only defined for multicast (NVLS) capability
pid_t rootPid;
int fileDesc;
int rootFd;
int rootPid;
};
};
size_t offsetFromBase;
Expand Down
59 changes: 59 additions & 0 deletions src/include/unix_socket.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

#ifndef MSCCLPP_UNIX_SOCKET_HPP_
#define MSCCLPP_UNIX_SOCKET_HPP_

#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>

namespace mscclpp {

class UnixSocketServer {
public:
static UnixSocketServer& instance();
static std::string generateSocketPath(int socketId);

void start();
void stop();
int registerFd(int fd);
void unregisterFd(int fd);
std::string getSocketPath() const;

private:
int listenUnixSockFd_ = -1;
std::string listenUnixSockPath_;
std::thread mainThread_;
std::unique_ptr<uint32_t> abortFlagStorage_;
volatile uint32_t* abortFlag_;
std::mutex mutex_;
std::unordered_set<int> fdSet_;

UnixSocketServer();
void mainLoop(int listenUnixSockFd);
};

class UnixSocketClient {
public:
static UnixSocketClient& instance();

int requestFd(const std::string& socketPath, uint32_t fdId);
void reset();
~UnixSocketClient();

private:
std::unordered_map<std::string, int> cachedFds_;
std::mutex mutex_;

UnixSocketClient() = default;
int requestFdInternal(int connFd, uint32_t fdId);
};

} // namespace mscclpp

#endif // MSCCLPP_UNIX_SOCKET_HPP_
29 changes: 12 additions & 17 deletions src/nvls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "api.h"
#include "debug.h"
#include "endpoint.hpp"
#include "unix_socket.hpp"

namespace mscclpp {

Expand Down Expand Up @@ -42,9 +43,12 @@ class NvlsConnection::Impl : public std::enable_shared_from_this<NvlsConnection:
size_t minMcGran_;
size_t mcGran_;
// These are only defined for multicast (NVLS) capability
pid_t rootPid_;
int rootFd_;
int rootPid_;
int mcFileDesc_;

UnixSocketClient& socketClient_ = UnixSocketClient::instance();

std::list<std::pair<size_t, size_t>> allocatedRanges_;
std::list<std::pair<size_t, size_t>> freeRanges_;
};
Expand All @@ -67,11 +71,8 @@ NvlsConnection::Impl::Impl(size_t bufferSize, int numDevices) {
MSCCLPP_CUTHROW(
cuMemExportToShareableHandle(&mcFileDesc_, mcHandle_, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0 /*flags*/));
freeRanges_.emplace_back(0, bufferSize_);

rootPid_ = getpid();
if (rootPid_ < 0) {
throw mscclpp::SysError("getpid() failed", errno);
}
rootFd_ = UnixSocketServer::instance().registerFd(mcFileDesc_);

INFO(MSCCLPP_COLL,
"NVLS handle created on root with size %ld. minGranularity %ld and recommendedGranularity %ld buffer size is "
Expand All @@ -91,29 +92,23 @@ NvlsConnection::Impl::Impl(const std::vector<char>& data) {
it += sizeof(this->mcGran_);
std::copy_n(it, sizeof(this->rootPid_), reinterpret_cast<char*>(&this->rootPid_));
it += sizeof(this->rootPid_);
std::copy_n(it, sizeof(this->mcFileDesc_), reinterpret_cast<char*>(&this->mcFileDesc_));
std::copy_n(it, sizeof(this->rootFd_), reinterpret_cast<char*>(&this->rootFd_));

freeRanges_.emplace_back(0, bufferSize_);
int rootPidFd = syscall(SYS_pidfd_open, rootPid_, 0);
if (rootPidFd < 0) {
throw mscclpp::SysError("pidfd_open() failed", errno);
}
int mcRootFileDescFd = syscall(SYS_pidfd_getfd, rootPidFd, mcFileDesc_, 0);
if (mcRootFileDescFd < 0) {
throw mscclpp::SysError("pidfd_getfd() failed", errno);
}
int mcRootFileDescFd = socketClient_.requestFd(UnixSocketServer::generateSocketPath(this->rootPid_), rootFd_);
MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&mcHandle_, reinterpret_cast<void*>(mcRootFileDescFd),
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
close(rootPidFd);
close(mcRootFileDescFd);

INFO(MSCCLPP_COLL, "NVLS handle was imported from root");
}

NvlsConnection::Impl::~Impl() {
// we don't need to free multicast handle object according to NCCL.
if (rootPid_ == getpid()) {
if (mcFileDesc_ >= 0) {
UnixSocketServer::instance().unregisterFd(rootFd_);
close(mcFileDesc_);
mcFileDesc_ = -1;
}
}

Expand All @@ -124,7 +119,7 @@ std::vector<char> NvlsConnection::Impl::serialize() {
std::copy_n(reinterpret_cast<char*>(&minMcGran_), sizeof(minMcGran_), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&mcGran_), sizeof(mcGran_), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&rootPid_), sizeof(rootPid_), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&mcFileDesc_), sizeof(mcFileDesc_), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&rootFd_), sizeof(rootFd_), std::back_inserter(result));
return result;
}

Expand Down
Loading
Loading