@@ -118,7 +118,7 @@ static inline void mscclppNcclDlopenFinalize() {
118118}
119119
120120static inline int mscclppNcclInFallbackList (const char * collOps, const char * fallbackList) {
121- if (fallbackList == nullptr || fallbackList[ 0 ] == ' \0 ' || strcmp (fallbackList, " all" ) == 0 ) {
121+ if (strcmp (fallbackList, " all" ) == 0 ) {
122122 return 1 ;
123123 }
124124
@@ -207,6 +207,7 @@ struct ncclComm {
207207 uint32_t buffFlag;
208208
209209 int nRanksPerNode;
210+ int worldSize;
210211
211212 std::shared_ptr<uint32_t > deviceFlag7;
212213 std::shared_ptr<uint32_t > deviceFlag28;
@@ -703,10 +704,15 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
703704 commPtr->comm = mscclppComm;
704705 commPtr->executor = std::make_shared<mscclpp::Executor>(mscclppComm);
705706 commPtr->nRanksPerNode = mscclppComm->bootstrap ()->getNranksPerNode ();
707+ commPtr->worldSize = mscclppComm->bootstrap ()->getNranks ();
708+
709+ if (commPtr->worldSize == 1 ) {
710+ *comm = commPtr;
711+ return ncclSuccess;
712+ }
706713
707714 // FallBack for single node
708- if (mscclppComm->bootstrap ()->getNranks () == mscclppComm->bootstrap ()->getNranksPerNode ())
709- ncclCommInitRankFallbackSingleNode (commPtr, mscclppComm, rank);
715+ if (commPtr->worldSize == commPtr->nRanksPerNode ) ncclCommInitRankFallbackSingleNode (commPtr, mscclppComm, rank);
710716
711717 const std::string& collectiveDir = mscclpp::env ()->executionPlanDir ;
712718 if (collectiveDir != " " ) {
@@ -759,7 +765,12 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
759765 return ncclSuccess;
760766}
761767
762- NCCL_API ncclResult_t ncclCommInitAll (ncclComm_t*, int , const int *) {
768+ NCCL_API ncclResult_t ncclCommInitAll (ncclComm_t* comm, int ndev, const int *) {
769+ if (ndev == 1 ) {
770+ ncclUniqueId Id;
771+ ncclGetUniqueId (&Id);
772+ return ncclCommInitRank (comm, ndev, Id, 0 );
773+ }
763774 // TODO: implement this function
764775 WARN (" ncclCommInitAll is currently unavailable" );
765776 return ncclInternalError;
@@ -987,6 +998,14 @@ NCCL_API ncclResult_t ncclBroadcastFallback(const void* sendbuff, void* recvbuff
987998
988999NCCL_API ncclResult_t ncclBroadcast (const void * sendbuff, void * recvbuff, size_t count, ncclDataType_t datatype,
9891000 int root, ncclComm_t comm, cudaStream_t stream) {
1001+ if (comm->worldSize == 1 ) {
1002+ if (sendbuff != recvbuff) {
1003+ size_t bytes = count * ncclTypeSize (datatype);
1004+ CUDACHECK (cudaMemcpyAsync (recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
1005+ }
1006+ return ncclSuccess;
1007+ }
1008+
9901009 size_t bytes = count * ncclTypeSize (datatype);
9911010 if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr ) {
9921011 WARN (
@@ -996,7 +1015,7 @@ NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t
9961015 }
9971016
9981017 int rank = comm->comm ->bootstrap ()->getRank ();
999- INFO (MSCCLPP_INIT , " rank %d broadcast sendbuff %p recvbuff %p count %ld, dtype %d, comm: %p" , rank, sendbuff,
1018+ INFO (MSCCLPP_NCCL , " rank %d broadcast sendbuff %p recvbuff %p count %ld, dtype %d, comm: %p" , rank, sendbuff,
10001019 recvbuff, count, datatype, comm);
10011020
10021021 const char * fallbackList = mscclpp::env ()->forceNcclFallbackOperation .c_str ();
@@ -1047,6 +1066,13 @@ NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t
10471066
10481067NCCL_API ncclResult_t ncclAllReduce (const void * sendbuff, void * recvbuff, size_t count, ncclDataType_t datatype,
10491068 ncclRedOp_t reductionOperation, ncclComm_t comm, cudaStream_t stream) {
1069+ if (comm->worldSize == 1 ) {
1070+ if (sendbuff != recvbuff) {
1071+ size_t bytes = count * ncclTypeSize (datatype);
1072+ CUDACHECK (cudaMemcpyAsync (recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
1073+ }
1074+ return ncclSuccess;
1075+ }
10501076 // Checking if the parameters are valids
10511077 if (sendbuff == nullptr || recvbuff == nullptr || count == 0 || ncclTypeSize (datatype) == 0 || comm == nullptr ) {
10521078 WARN (
@@ -1076,8 +1102,17 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
10761102 }
10771103 }
10781104
1079- if (plan == nullptr )
1105+ int nRanks = comm->comm ->bootstrap ()->getNranks ();
1106+ int nRanksPerNode = comm->comm ->bootstrap ()->getNranksPerNode ();
1107+ if (plan == nullptr && nRanks == nRanksPerNode)
10801108 return ncclAllReduceFallback (sendbuff, recvbuff, count, datatype, reductionOperation, comm, stream);
1109+ if (plan == nullptr && mscclppNcclDlopenSharedLib) {
1110+ return mscclppNcclOps.AllReduce (sendbuff, recvbuff, count, datatype, reductionOperation,
1111+ *reinterpret_cast <ncclComm_t*>(comm->mscclppNcclComm ), stream);
1112+ } else if (plan == nullptr ) {
1113+ WARN (" No FallBack code for AllReduce when multi-node" );
1114+ return ncclInternalError;
1115+ }
10811116
10821117 switch (datatype) {
10831118 case ncclFloat16:
@@ -1107,6 +1142,14 @@ NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t
11071142
11081143NCCL_API ncclResult_t ncclReduceScatter (const void * sendbuff, void * recvbuff, size_t recvcount, ncclDataType_t datatype,
11091144 ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
1145+ if (comm->worldSize == 1 ) {
1146+ if (sendbuff != recvbuff) {
1147+ size_t bytes = recvcount * ncclTypeSize (datatype);
1148+ CUDACHECK (cudaMemcpyAsync (recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
1149+ }
1150+ return ncclSuccess;
1151+ }
1152+
11101153 size_t bytes = recvcount * ncclTypeSize (datatype);
11111154 if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr ) {
11121155 WARN (
@@ -1169,6 +1212,13 @@ NCCL_API ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, si
11691212
11701213NCCL_API ncclResult_t ncclAllGather (const void * sendbuff, void * recvbuff, size_t sendcount, ncclDataType_t datatype,
11711214 ncclComm_t comm, cudaStream_t stream) {
1215+ if (comm->worldSize == 1 ) {
1216+ if (sendbuff != recvbuff) {
1217+ size_t bytes = sendcount * ncclTypeSize (datatype);
1218+ CUDACHECK (cudaMemcpyAsync (recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
1219+ }
1220+ return ncclSuccess;
1221+ }
11721222 size_t bytes = sendcount * ncclTypeSize (datatype);
11731223 if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr ) {
11741224 WARN (
@@ -1239,14 +1289,31 @@ NCCL_API ncclResult_t ncclRecv(void*, size_t, ncclDataType_t, int, ncclComm_t, c
12391289 return ncclInternalError;
12401290}
12411291
1242- NCCL_API ncclResult_t ncclAllToAll (const void *, void *, size_t , ncclDataType_t, ncclComm_t, cudaStream_t) {
1292+ NCCL_API ncclResult_t ncclAllToAll (const void * sendbuff, void * recvbuff, size_t count, ncclDataType_t datatype,
1293+ ncclComm_t comm, cudaStream_t stream) {
1294+ if (comm->worldSize == 1 ) {
1295+ if (sendbuff != recvbuff) {
1296+ size_t bytes = count * ncclTypeSize (datatype);
1297+ CUDACHECK (cudaMemcpyAsync (recvbuff, sendbuff, bytes, cudaMemcpyDeviceToDevice, stream));
1298+ }
1299+ return ncclSuccess;
1300+ }
12431301 // TODO: implement this function
12441302 WARN (" ncclAllToAll is currently unavailable" );
12451303 return ncclInternalError;
12461304}
12471305
1248- NCCL_API ncclResult_t ncclAllToAllv (const void *, const size_t [], const size_t [], void *, const size_t [], const size_t [],
1249- ncclDataType_t, ncclComm_t, cudaStream_t) {
1306+ NCCL_API ncclResult_t ncclAllToAllv (const void * sendbuff, [[maybe_unused]] const size_t sendcounts[],
1307+ const size_t sdispls[], void * recvbuff, const size_t recvcounts[],
1308+ const size_t rdispls[], ncclDataType_t datatype, ncclComm_t comm,
1309+ cudaStream_t stream) {
1310+ if (comm->worldSize == 1 ) {
1311+ size_t bytes = recvcounts[0 ] * ncclTypeSize (datatype);
1312+ MSCCLPP_CUDATHROW (cudaMemcpyAsync ((char *)recvbuff + rdispls[0 ] * ncclTypeSize (datatype),
1313+ (const char *)sendbuff + sdispls[0 ] * ncclTypeSize (datatype), bytes,
1314+ cudaMemcpyDeviceToDevice, stream));
1315+ return ncclSuccess;
1316+ }
12501317 WARN (" ncclAllToAllv is currently unavailable" );
12511318 return ncclInternalError;
12521319}
0 commit comments