Skip to content

Commit bf8d424

Browse files
Binyang2014 and Copilot authored
use unix socket to share fd (#634)
Use a Unix socket to share file descriptors (fds) with other processes. This is used for NVLS handle sharing.
Update the NCCL interface to support worldSize=1.
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 43f160c commit bf8d424

File tree

9 files changed

+516
-46
lines changed

9 files changed

+516
-46
lines changed

apps/nccl/src/nccl.cu

Lines changed: 76 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -118,7 +118,7 @@ static inline void mscclppNcclDlopenFinalize() {
118118
}
119119

120120
static inline int mscclppNcclInFallbackList(const char* collOps, const char* fallbackList) {
121-
if (fallbackList == nullptr || fallbackList[0] == '\0' || strcmp(fallbackList, "all") == 0) {
121+
if (strcmp(fallbackList, "all") == 0) {
122122
return 1;
123123
}
124124

@@ -207,6 +207,7 @@ struct ncclComm {
207207
uint32_t buffFlag;
208208

209209
int nRanksPerNode;
210+
int worldSize;
210211

211212
std::shared_ptr<uint32_t> deviceFlag7;
212213
std::shared_ptr<uint32_t> deviceFlag28;
@@ -703,10 +704,15 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
703704
commPtr->comm = mscclppComm;
704705
commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm);
705706
commPtr->nRanksPerNode = mscclppComm->bootstrap()->getNranksPerNode();
707+
commPtr->worldSize = mscclppComm->bootstrap()->getNranks();
708+
709+
if (commPtr->worldSize == 1) {
710+
*comm = commPtr;
711+
return ncclSuccess;
712+
}
706713

707714
// FallBack for single node
708-
if (mscclppComm->bootstrap()->getNranks() == mscclppComm->bootstrap()->getNranksPerNode())
709-
ncclCommInitRankFallbackSingleNode(commPtr, mscclppComm, rank);
715+
if (commPtr->worldSize == commPtr->nRanksPerNode) ncclCommInitRankFallbackSingleNode(commPtr, mscclppComm, rank);
710716

711717
const std::string& collectiveDir = mscclpp::env()->executionPlanDir;
712718
if (collectiveDir != "") {
@@ -759,7 +765,12 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
759765
return ncclSuccess;
760766
}
761767

762-
NCCL_API ncclResult_t ncclCommInitAll(ncclComm_t*, int, const int*) {
768+
NCCL_API ncclResult_t ncclCommInitAll(ncclComm_t* comm, int ndev, const int*) {
769+
if (ndev == 1) {
770+
ncclUniqueId Id;
771+
ncclGetUniqueId(&Id);
772+
return ncclCommInitRank(comm, ndev, Id, 0);
773+
}
763774
// TODO: implement this function
764775
WARN("ncclCommInitAll is currently unavailable");
765776
return ncclInternalError;
@@ -987,6 +998,14 @@ NCCL_API ncclResult_t ncclBroadcastFallback(const void* sendbuff, void* recvbuff
987998

988999
NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
9891000
int root, ncclComm_t comm, cudaStream_t stream) {
1001+
if (comm->worldSize == 1) {
1002+
if (sendbuff != recvbuff) {
1003+
size_t bytes = count * ncclTypeSize(datatype);
1004+
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
1005+
}
1006+
return ncclSuccess;
1007+
}
1008+
9901009
size_t bytes = count * ncclTypeSize(datatype);
9911010
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) {
9921011
WARN(
@@ -996,7 +1015,7 @@ NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t
9961015
}
9971016

9981017
int rank = comm->comm->bootstrap()->getRank();
999-
INFO(MSCCLPP_INIT, "rank %d broadcast sendbuff %p recvbuff %p count %ld, dtype %d, comm: %p", rank, sendbuff,
1018+
INFO(MSCCLPP_NCCL, "rank %d broadcast sendbuff %p recvbuff %p count %ld, dtype %d, comm: %p", rank, sendbuff,
10001019
recvbuff, count, datatype, comm);
10011020

10021021
const char* fallbackList = mscclpp::env()->forceNcclFallbackOperation.c_str();
@@ -1047,6 +1066,13 @@ NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t
10471066

10481067
NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
10491068
ncclRedOp_t reductionOperation, ncclComm_t comm, cudaStream_t stream) {
1069+
if (comm->worldSize == 1) {
1070+
if (sendbuff != recvbuff) {
1071+
size_t bytes = count * ncclTypeSize(datatype);
1072+
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
1073+
}
1074+
return ncclSuccess;
1075+
}
10501076
// Checking if the parameters are valids
10511077
if (sendbuff == nullptr || recvbuff == nullptr || count == 0 || ncclTypeSize(datatype) == 0 || comm == nullptr) {
10521078
WARN(
@@ -1076,8 +1102,17 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
10761102
}
10771103
}
10781104

1079-
if (plan == nullptr)
1105+
int nRanks = comm->comm->bootstrap()->getNranks();
1106+
int nRanksPerNode = comm->comm->bootstrap()->getNranksPerNode();
1107+
if (plan == nullptr && nRanks == nRanksPerNode)
10801108
return ncclAllReduceFallback(sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
1109+
if (plan == nullptr && mscclppNcclDlopenSharedLib) {
1110+
return mscclppNcclOps.AllReduce(sendbuff, recvbuff, count, datatype, reductionOperation,
1111+
*reinterpret_cast<ncclComm_t*>(comm->mscclppNcclComm), stream);
1112+
} else if (plan == nullptr) {
1113+
WARN("No FallBack code for AllReduce when multi-node");
1114+
return ncclInternalError;
1115+
}
10811116

10821117
switch (datatype) {
10831118
case ncclFloat16:
@@ -1107,6 +1142,14 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
11071142

11081143
NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, size_t recvcount, ncclDataType_t datatype,
11091144
ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
1145+
if (comm->worldSize == 1) {
1146+
if (sendbuff != recvbuff) {
1147+
size_t bytes = recvcount * ncclTypeSize(datatype);
1148+
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
1149+
}
1150+
return ncclSuccess;
1151+
}
1152+
11101153
size_t bytes = recvcount * ncclTypeSize(datatype);
11111154
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) {
11121155
WARN(
@@ -1169,6 +1212,13 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, si
11691212

11701213
NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, ncclDataType_t datatype,
11711214
ncclComm_t comm, cudaStream_t stream) {
1215+
if (comm->worldSize == 1) {
1216+
if (sendbuff != recvbuff) {
1217+
size_t bytes = sendcount * ncclTypeSize(datatype);
1218+
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
1219+
}
1220+
return ncclSuccess;
1221+
}
11721222
size_t bytes = sendcount * ncclTypeSize(datatype);
11731223
if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) {
11741224
WARN(
@@ -1239,14 +1289,31 @@ NCCL_API ncclResult_t ncclRecv(void*, size_t, ncclDataType_t, int, ncclComm_t, c
12391289
return ncclInternalError;
12401290
}
12411291

1242-
NCCL_API ncclResult_t ncclAllToAll(const void*, void*, size_t, ncclDataType_t, ncclComm_t, cudaStream_t) {
1292+
NCCL_API ncclResult_t ncclAllToAll(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
1293+
ncclComm_t comm, cudaStream_t stream) {
1294+
if (comm->worldSize == 1) {
1295+
if (sendbuff != recvbuff) {
1296+
size_t bytes = count * ncclTypeSize(datatype);
1297+
CUDACHECK(cudaMemcpyAsync(recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
1298+
}
1299+
return ncclSuccess;
1300+
}
12431301
// TODO: implement this function
12441302
WARN("ncclAllToAll is currently unavailable");
12451303
return ncclInternalError;
12461304
}
12471305

1248-
NCCL_API ncclResult_t ncclAllToAllv(const void*, const size_t[], const size_t[], void*, const size_t[], const size_t[],
1249-
ncclDataType_t, ncclComm_t, cudaStream_t) {
1306+
NCCL_API ncclResult_t ncclAllToAllv(const void* sendbuff, [[maybe_unused]] const size_t sendcounts[],
1307+
const size_t sdispls[], void* recvbuff, const size_t recvcounts[],
1308+
const size_t rdispls[], ncclDataType_t datatype, ncclComm_t comm,
1309+
cudaStream_t stream) {
1310+
if (comm->worldSize == 1) {
1311+
size_t bytes = recvcounts[0] * ncclTypeSize(datatype);
1312+
MSCCLPP_CUDATHROW(cudaMemcpyAsync((char*)recvbuff + rdispls[0] * ncclTypeSize(datatype),
1313+
(const char*)sendbuff + sdispls[0] * ncclTypeSize(datatype), bytes,
1314+
cudaMemcpyDeviceToDevice, stream));
1315+
return ncclSuccess;
1316+
}
12501317
WARN("ncclAllToAllv is currently unavailable");
12511318
return ncclInternalError;
12521319
}

src/bootstrap/bootstrap.cc

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
#include "api.h"
1515
#include "debug.h"
1616
#include "socket.h"
17+
#include "unix_socket.hpp"
1718
#include "utils_internal.hpp"
1819

1920
namespace mscclpp {
@@ -114,6 +115,8 @@ class TcpBootstrap::Impl {
114115
std::shared_ptr<Socket> getPeerSendSocket(int peer, int tag);
115116
std::shared_ptr<Socket> getPeerRecvSocket(int peer, int tag);
116117

118+
UnixSocketServer& unixSocketServer_;
119+
117120
static void assignPortToUniqueId(UniqueIdInternal& uniqueId);
118121
static void netInit(std::string ipPortPair, std::string interface, SocketAddress& netIfAddr);
119122

@@ -149,7 +152,8 @@ TcpBootstrap::Impl::Impl(int rank, int nRanks)
149152
peerCommAddresses_(nRanks, SocketAddress()),
150153
barrierArr_(nRanks, 0),
151154
abortFlagStorage_(new uint32_t(0)),
152-
abortFlag_(abortFlagStorage_.get()) {}
155+
abortFlag_(abortFlagStorage_.get()),
156+
unixSocketServer_(UnixSocketServer::instance()) {}
153157

154158
UniqueId TcpBootstrap::Impl::getUniqueId() const { return getUniqueId(uniqueId_); }
155159

@@ -172,6 +176,9 @@ void TcpBootstrap::Impl::initialize(const UniqueId& uniqueId, int64_t timeoutSec
172176
SocketToString(&uniqueId_.addr, line);
173177
INFO(MSCCLPP_INIT, "rank %d nranks %d - connecting to %s", rank_, nRanks_, line);
174178
establishConnections(timeoutSec);
179+
180+
unixSocketServer_.start();
181+
INFO(MSCCLPP_INIT, "rank %d - unix socket server started", rank_);
175182
}
176183

177184
void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t timeoutSec) {
@@ -204,6 +211,8 @@ void TcpBootstrap::Impl::initialize(const std::string& ifIpPortTrio, int64_t tim
204211
}
205212

206213
establishConnections(timeoutSec);
214+
unixSocketServer_.start();
215+
INFO(MSCCLPP_INIT, "rank %d - unix socket server started", rank_);
207216
}
208217

209218
TcpBootstrap::Impl::~Impl() {
@@ -567,6 +576,8 @@ void TcpBootstrap::Impl::close() {
567576
ringSendSocket_.reset(nullptr);
568577
peerSendSockets_.clear();
569578
peerRecvSockets_.clear();
579+
unixSocketServer_.stop();
580+
UnixSocketClient::instance().reset();
570581
}
571582

572583
MSCCLPP_API_CPP UniqueId TcpBootstrap::createUniqueId() { return Impl::createUniqueId(); }

src/include/registered_memory.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -32,8 +32,8 @@ struct TransportInfo {
3232
char shareableHandle[64];
3333
struct {
3434
// These are only defined for multicast (NVLS) capability
35-
pid_t rootPid;
36-
int fileDesc;
35+
int rootFd;
36+
int rootPid;
3737
};
3838
};
3939
size_t offsetFromBase;

src/include/unix_socket.hpp

Lines changed: 59 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,59 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT License.
3+
4+
#ifndef MSCCLPP_UNIX_SOCKET_HPP_
5+
#define MSCCLPP_UNIX_SOCKET_HPP_
6+
7+
#include <cstdint>
8+
#include <memory>
9+
#include <mutex>
10+
#include <string>
11+
#include <thread>
12+
#include <unordered_map>
13+
#include <unordered_set>
14+
15+
namespace mscclpp {
16+
17+
class UnixSocketServer {
18+
public:
19+
static UnixSocketServer& instance();
20+
static std::string generateSocketPath(int socketId);
21+
22+
void start();
23+
void stop();
24+
int registerFd(int fd);
25+
void unregisterFd(int fd);
26+
std::string getSocketPath() const;
27+
28+
private:
29+
int listenUnixSockFd_ = -1;
30+
std::string listenUnixSockPath_;
31+
std::thread mainThread_;
32+
std::unique_ptr<uint32_t> abortFlagStorage_;
33+
volatile uint32_t* abortFlag_;
34+
std::mutex mutex_;
35+
std::unordered_set<int> fdSet_;
36+
37+
UnixSocketServer();
38+
void mainLoop(int listenUnixSockFd);
39+
};
40+
41+
class UnixSocketClient {
42+
public:
43+
static UnixSocketClient& instance();
44+
45+
int requestFd(const std::string& socketPath, uint32_t fdId);
46+
void reset();
47+
~UnixSocketClient();
48+
49+
private:
50+
std::unordered_map<std::string, int> cachedFds_;
51+
std::mutex mutex_;
52+
53+
UnixSocketClient() = default;
54+
int requestFdInternal(int connFd, uint32_t fdId);
55+
};
56+
57+
} // namespace mscclpp
58+
59+
#endif // MSCCLPP_UNIX_SOCKET_HPP_

src/nvls.cc

Lines changed: 12 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
#include "api.h"
1313
#include "debug.h"
1414
#include "endpoint.hpp"
15+
#include "unix_socket.hpp"
1516

1617
namespace mscclpp {
1718

@@ -42,9 +43,12 @@ class NvlsConnection::Impl : public std::enable_shared_from_this<NvlsConnection:
4243
size_t minMcGran_;
4344
size_t mcGran_;
4445
// These are only defined for multicast (NVLS) capability
45-
pid_t rootPid_;
46+
int rootFd_;
47+
int rootPid_;
4648
int mcFileDesc_;
4749

50+
UnixSocketClient& socketClient_ = UnixSocketClient::instance();
51+
4852
std::list<std::pair<size_t, size_t>> allocatedRanges_;
4953
std::list<std::pair<size_t, size_t>> freeRanges_;
5054
};
@@ -67,11 +71,8 @@ NvlsConnection::Impl::Impl(size_t bufferSize, int numDevices) {
6771
MSCCLPP_CUTHROW(
6872
cuMemExportToShareableHandle(&mcFileDesc_, mcHandle_, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0 /*flags*/));
6973
freeRanges_.emplace_back(0, bufferSize_);
70-
7174
rootPid_ = getpid();
72-
if (rootPid_ < 0) {
73-
throw mscclpp::SysError("getpid() failed", errno);
74-
}
75+
rootFd_ = UnixSocketServer::instance().registerFd(mcFileDesc_);
7576

7677
INFO(MSCCLPP_COLL,
7778
"NVLS handle created on root with size %ld. minGranularity %ld and recommendedGranularity %ld buffer size is "
@@ -91,29 +92,23 @@ NvlsConnection::Impl::Impl(const std::vector<char>& data) {
9192
it += sizeof(this->mcGran_);
9293
std::copy_n(it, sizeof(this->rootPid_), reinterpret_cast<char*>(&this->rootPid_));
9394
it += sizeof(this->rootPid_);
94-
std::copy_n(it, sizeof(this->mcFileDesc_), reinterpret_cast<char*>(&this->mcFileDesc_));
95+
std::copy_n(it, sizeof(this->rootFd_), reinterpret_cast<char*>(&this->rootFd_));
9596

9697
freeRanges_.emplace_back(0, bufferSize_);
97-
int rootPidFd = syscall(SYS_pidfd_open, rootPid_, 0);
98-
if (rootPidFd < 0) {
99-
throw mscclpp::SysError("pidfd_open() failed", errno);
100-
}
101-
int mcRootFileDescFd = syscall(SYS_pidfd_getfd, rootPidFd, mcFileDesc_, 0);
102-
if (mcRootFileDescFd < 0) {
103-
throw mscclpp::SysError("pidfd_getfd() failed", errno);
104-
}
98+
int mcRootFileDescFd = socketClient_.requestFd(UnixSocketServer::generateSocketPath(this->rootPid_), rootFd_);
10599
MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&mcHandle_, reinterpret_cast<void*>(mcRootFileDescFd),
106100
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
107-
close(rootPidFd);
108101
close(mcRootFileDescFd);
109102

110103
INFO(MSCCLPP_COLL, "NVLS handle was imported from root");
111104
}
112105

113106
NvlsConnection::Impl::~Impl() {
114107
// we don't need to free multicast handle object according to NCCL.
115-
if (rootPid_ == getpid()) {
108+
if (mcFileDesc_ >= 0) {
109+
UnixSocketServer::instance().unregisterFd(rootFd_);
116110
close(mcFileDesc_);
111+
mcFileDesc_ = -1;
117112
}
118113
}
119114

@@ -124,7 +119,7 @@ std::vector<char> NvlsConnection::Impl::serialize() {
124119
std::copy_n(reinterpret_cast<char*>(&minMcGran_), sizeof(minMcGran_), std::back_inserter(result));
125120
std::copy_n(reinterpret_cast<char*>(&mcGran_), sizeof(mcGran_), std::back_inserter(result));
126121
std::copy_n(reinterpret_cast<char*>(&rootPid_), sizeof(rootPid_), std::back_inserter(result));
127-
std::copy_n(reinterpret_cast<char*>(&mcFileDesc_), sizeof(mcFileDesc_), std::back_inserter(result));
122+
std::copy_n(reinterpret_cast<char*>(&rootFd_), sizeof(rootFd_), std::back_inserter(result));
128123
return result;
129124
}
130125

0 commit comments

Comments
 (0)