Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/mscclpp/env.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ class Env {
/// Default is false.
const bool forceDisableNvls;

int localRank = -1; // to be set by the bootstrap

private:
Env();

Expand Down
12 changes: 11 additions & 1 deletion src/bootstrap/bootstrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <cstring>
#include <mscclpp/core.hpp>
#include <mscclpp/env.hpp>
#include <mscclpp/errors.hpp>
#include <sstream>
#include <thread>
Expand All @@ -14,6 +15,7 @@
#include "api.h"
#include "debug.h"
#include "socket.h"
#include "unix_socket.hpp"
#include "utils_internal.hpp"

namespace mscclpp {
Expand Down Expand Up @@ -114,6 +116,8 @@ class TcpBootstrap::Impl {
std::shared_ptr<Socket> getPeerSendSocket(int peer, int tag);
std::shared_ptr<Socket> getPeerRecvSocket(int peer, int tag);

UnixSocketServer& unixSocketServer_;

static void assignPortToUniqueId(UniqueIdInternal& uniqueId);
static void netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr);

Expand Down Expand Up @@ -149,7 +153,8 @@ TcpBootstrap::Impl::Impl(int rank, int nRanks)
peerCommAddresses_(nRanks, SocketAddress()),
barrierArr_(nRanks, 0),
abortFlagStorage_(new uint32_t(0)),
abortFlag_(abortFlagStorage_.get()) {}
abortFlag_(abortFlagStorage_.get()),
unixSocketServer_(UnixSocketServer::instance()) {}

UniqueId TcpBootstrap::Impl::getUniqueId() const { return getUniqueId(uniqueId_); }

Expand All @@ -172,6 +177,10 @@ void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId, int64_t timeoutSec
SocketToString(&uniqueId_.addr, line);
INFO(MSCCLPP_INIT, "rank %d nranks %d - connecting to %s", rank_, nRanks_, line);
establishConnections(timeoutSec);

env()->localRank = rank_ % getNranksPerNode();
unixSocketServer_.start(rank_ % getNranksPerNode());
INFO(MSCCLPP_INIT, "rank %d - unix socket server started", rank_);
}

void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t timeoutSec) {
Expand Down Expand Up @@ -567,6 +576,7 @@ void TcpBootstrap::Impl::close() {
ringSendSocket_.reset(nullptr);
peerSendSockets_.clear();
peerRecvSockets_.clear();
unixSocketServer_.stop();
}

MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() { return Impl::createUniqueId(); }
Expand Down
4 changes: 2 additions & 2 deletions src/include/registered_memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ struct TransportInfo {
char shareableHandle[64];
struct {
// These are only defined for multicast (NVLS) capability
pid_t rootPid;
int fileDesc;
int rootFdId;
int rootLocalRankId;
};
};
size_t offsetFromBase;
Expand Down
56 changes: 56 additions & 0 deletions src/include/unix_socket.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#ifndef MSCCLPP_UNIX_SOCKET_HPP_
#define MSCCLPP_UNIX_SOCKET_HPP_

#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>

namespace mscclpp {

class UnixSocketServer {
public:
static UnixSocketServer& instance();
static std::string generateSocketPath(int localRankId);

void start(int localRankId);
void stop();
uint32_t registerFd(int fd);
void unregisterFd(uint32_t fdId);
std::string getSocketPath() const;

private:
int listenUnixSockFd_ = -1;
std::string listenUnixSockPath_;
std::thread mainThread_;
std::unique_ptr<uint32_t> abortFlagStorage_;
volatile uint32_t* abortFlag_;
std::mutex mutex_;
std::unordered_map<uint32_t, int> fdMap_;

UnixSocketServer();
void mainLoop();
};

class UnixSocketClient {
public:
static UnixSocketClient& instance();

int requestFd(const std::string& socketPath, uint32_t fdId);
~UnixSocketClient();

private:
std::unordered_map<std::string, int> cachedFds_;
std::mutex mutex_;

UnixSocketClient() = default;
int requestFdInternal(int connFd, uint32_t fdId);
};

} // namespace mscclpp

#endif // MSCCLPP_UNIX_SOCKET_HPP_
37 changes: 16 additions & 21 deletions src/nvls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "api.h"
#include "debug.h"
#include "endpoint.hpp"
#include "unix_socket.hpp"

namespace mscclpp {

Expand Down Expand Up @@ -42,9 +43,12 @@ class NvlsConnection::Impl : public std::enable_shared_from_this<NvlsConnection:
size_t minMcGran_;
size_t mcGran_;
// These are only defined for multicast (NVLS) capability
pid_t rootPid_;
int rootFdId_;
int rootLocalRankId_;
int mcFileDesc_;

UnixSocketClient& socketClient_ = UnixSocketClient::instance();

std::list<std::pair<size_t, size_t>> allocatedRanges_;
std::list<std::pair<size_t, size_t>> freeRanges_;
};
Expand All @@ -67,11 +71,8 @@ NvlsConnection::Impl::Impl(size_t bufferSize, int numDevices) {
MSCCLPP_CUTHROW(
cuMemExportToShareableHandle(&mcFileDesc_, mcHandle_, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0 /*flags*/));
freeRanges_.emplace_back(0, bufferSize_);

rootPid_ = getpid();
if (rootPid_ < 0) {
throw mscclpp::SysError("getpid() failed", errno);
}
rootLocalRankId_ = env()->localRank;
rootFdId_ = UnixSocketServer::instance().registerFd(mcFileDesc_);

INFO(MSCCLPP_COLL,
"NVLS handle created on root with size %ld. minGranularity %ld and recommendedGranularity %ld buffer size is "
Expand All @@ -89,31 +90,25 @@ NvlsConnection::Impl::Impl(const std::vector<char>& data) {
it += sizeof(this->minMcGran_);
std::copy_n(it, sizeof(this->mcGran_), reinterpret_cast<char*>(&this->mcGran_));
it += sizeof(this->mcGran_);
std::copy_n(it, sizeof(this->rootPid_), reinterpret_cast<char*>(&this->rootPid_));
it += sizeof(this->rootPid_);
std::copy_n(it, sizeof(this->mcFileDesc_), reinterpret_cast<char*>(&this->mcFileDesc_));
std::copy_n(it, sizeof(this->rootLocalRankId_), reinterpret_cast<char*>(&this->rootLocalRankId_));
it += sizeof(this->rootLocalRankId_);
std::copy_n(it, sizeof(this->rootFdId_), reinterpret_cast<char*>(&this->rootFdId_));

freeRanges_.emplace_back(0, bufferSize_);
int rootPidFd = syscall(SYS_pidfd_open, rootPid_, 0);
if (rootPidFd < 0) {
throw mscclpp::SysError("pidfd_open() failed", errno);
}
int mcRootFileDescFd = syscall(SYS_pidfd_getfd, rootPidFd, mcFileDesc_, 0);
if (mcRootFileDescFd < 0) {
throw mscclpp::SysError("pidfd_getfd() failed", errno);
}
int mcRootFileDescFd = socketClient_.requestFd(UnixSocketServer::generateSocketPath(this->rootLocalRankId_), rootFdId_);
MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&mcHandle_, reinterpret_cast<void*>(mcRootFileDescFd),
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
close(rootPidFd);
close(mcRootFileDescFd);

INFO(MSCCLPP_COLL, "NVLS handle was imported from root");
}

NvlsConnection::Impl::~Impl() {
// we don't need to free multicast handle object according to NCCL.
if (rootPid_ == getpid()) {
if (mcFileDesc_ >= 0) {
UnixSocketServer::instance().unregisterFd(rootFdId_);
close(mcFileDesc_);
mcFileDesc_ = -1;
}
}

Expand All @@ -123,8 +118,8 @@ std::vector<char> NvlsConnection::Impl::serialize() {
std::copy_n(reinterpret_cast<char*>(&bufferSize_), sizeof(bufferSize_), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&minMcGran_), sizeof(minMcGran_), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&mcGran_), sizeof(mcGran_), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&rootPid_), sizeof(rootPid_), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&mcFileDesc_), sizeof(mcFileDesc_), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&rootLocalRankId_), sizeof(rootLocalRankId_), std::back_inserter(result));
std::copy_n(reinterpret_cast<char*>(&rootFdId_), sizeof(rootFdId_), std::back_inserter(result));
return result;
}

Expand Down
31 changes: 11 additions & 20 deletions src/registered_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "context.hpp"
#include "debug.h"
#include "serialization.hpp"
#include "unix_socket.hpp"
#include "utils_internal.hpp"

#define MSCCLPP_CULOG_WARN(cmd) \
Expand Down Expand Up @@ -66,12 +67,9 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, TransportFlags transports,
if (getNvlsMemHandleType() == CU_MEM_HANDLE_TYPE_FABRIC) {
MSCCLPP_CUTHROW(cuMemExportToShareableHandle(transportInfo.shareableHandle, handle, getNvlsMemHandleType(), 0));
} else {
transportInfo.rootPid = getpid();
if (transportInfo.rootPid < 0) {
throw SysError("getpid() failed", errno);
}
MSCCLPP_CUTHROW(cuMemExportToShareableHandle(&transportInfo.fileDesc, handle, getNvlsMemHandleType(), 0));
this->fileDesc = transportInfo.fileDesc;
MSCCLPP_CUTHROW(cuMemExportToShareableHandle(&this->fileDesc, handle, getNvlsMemHandleType(), 0));
transportInfo.rootFdId = UnixSocketServer::instance().registerFd(fileDesc);
transportInfo.rootLocalRankId = env()->localRank;
}
transportInfo.offsetFromBase = (char*)data - (char*)baseDataPtr;
MSCCLPP_CUTHROW(cuMemRelease(handle));
Expand Down Expand Up @@ -139,8 +137,8 @@ MSCCLPP_API_CPP std::vector<char> RegisteredMemory::serialize() const {
if (getNvlsMemHandleType() == CU_MEM_HANDLE_TYPE_FABRIC) {
detail::serialize(result, entry.shareableHandle);
} else {
detail::serialize(result, entry.rootPid);
detail::serialize(result, entry.fileDesc);
detail::serialize(result, entry.rootFdId);
detail::serialize(result, entry.rootLocalRankId);
}
detail::serialize(result, entry.offsetFromBase);
} else {
Expand Down Expand Up @@ -180,8 +178,8 @@ RegisteredMemory::Impl::Impl(const std::vector<char>::const_iterator& begin,
if (getNvlsMemHandleType() == CU_MEM_HANDLE_TYPE_FABRIC) {
it = detail::deserialize(it, transportInfo.shareableHandle);
} else {
it = detail::deserialize(it, transportInfo.rootPid);
it = detail::deserialize(it, transportInfo.fileDesc);
it = detail::deserialize(it, transportInfo.rootFdId);
it = detail::deserialize(it, transportInfo.rootLocalRankId);
}
it = detail::deserialize(it, transportInfo.offsetFromBase);
} else {
Expand Down Expand Up @@ -227,18 +225,11 @@ RegisteredMemory::Impl::Impl(const std::vector<char>::const_iterator& begin,
if (getNvlsMemHandleType() == CU_MEM_HANDLE_TYPE_FABRIC) {
MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, entry.shareableHandle, getNvlsMemHandleType()));
} else {
int rootPidFd = syscall(SYS_pidfd_open, entry.rootPid, 0);
if (rootPidFd < 0) {
throw SysError("pidfd_open() failed", errno);
}
int fd = syscall(SYS_pidfd_getfd, rootPidFd, entry.fileDesc, 0);
if (fd < 0) {
throw SysError("pidfd_getfd() failed", errno);
}
INFO(MSCCLPP_P2P, "Get file descriptor %d from pidfd %d on peer 0x%lx", fd, rootPidFd, hostHash);
int fd = UnixSocketClient::instance().requestFd(UnixSocketServer::generateSocketPath(entry.rootLocalRankId),
entry.rootFdId);
INFO(MSCCLPP_P2P, "Get file descriptor %d from peer 0x%lx", fd, hostHash);
MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, reinterpret_cast<void*>(fd),
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
close(rootPidFd);
close(fd);
}
}
Expand Down
Loading
Loading