@@ -58,6 +58,14 @@ typedef struct _mscclppNcclOps_t {
5858 ncclComm_t comm, cudaStream_t stream);
5959 ncclResult_t (*ReduceScatter)(const void * sendbuff, void * recvbuff, size_t recvcount, ncclDataType_t datatype,
6060 ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream);
61+ ncclResult_t (*Reduce)(const void * sendbuff, void * recvbuff, size_t count, ncclDataType_t datatype, ncclRedOp_t op,
62+ int root, ncclComm_t comm, cudaStream_t stream);
63+ ncclResult_t (*Send)(const void * sendbuff, size_t count, ncclDataType_t datatype, int peer, ncclComm_t comm,
64+ cudaStream_t stream);
65+ ncclResult_t (*Recv)(void * recvbuff, size_t count, ncclDataType_t datatype, int peer, ncclComm_t comm,
66+ cudaStream_t stream);
67+ ncclResult_t (*GroupStart)();
68+ ncclResult_t (*GroupEnd)();
6169} mscclppNcclOps_t;
6270
6371mscclppNcclOps_t mscclppNcclOps;
@@ -106,6 +114,14 @@ static inline int mscclppNcclDlopenInit() {
106114 ncclResult_t (*)(const void *, void *, size_t , ncclDataType_t, int , ncclComm_t, cudaStream_t));
107115 NCCL_DLSYM (mscclppNcclOps, mscclppNcclDlHandle, nccl, ReduceScatter,
108116 ncclResult_t (*)(const void *, void *, size_t , ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t));
117+ NCCL_DLSYM (mscclppNcclOps, mscclppNcclDlHandle, nccl, Reduce,
118+ ncclResult_t (*)(const void *, void *, size_t , ncclDataType_t, ncclRedOp_t, int , ncclComm_t, cudaStream_t));
119+ NCCL_DLSYM (mscclppNcclOps, mscclppNcclDlHandle, nccl, Send,
120+ ncclResult_t (*)(const void *, size_t , ncclDataType_t, int , ncclComm_t, cudaStream_t));
121+ NCCL_DLSYM (mscclppNcclOps, mscclppNcclDlHandle, nccl, Recv,
122+ ncclResult_t (*)(void *, size_t , ncclDataType_t, int , ncclComm_t, cudaStream_t));
123+ NCCL_DLSYM (mscclppNcclOps, mscclppNcclDlHandle, nccl, GroupStart, ncclResult_t (*)());
124+ NCCL_DLSYM (mscclppNcclOps, mscclppNcclDlHandle, nccl, GroupEnd, ncclResult_t (*)());
109125
110126 return dlopenSuccess;
111127}
@@ -135,6 +151,17 @@ static inline int mscclppNcclInFallbackList(const char* collOps, const char* fal
135151 return 0 ;
136152}
137153
154+ static bool tryLoadNcclSharedLib () {
155+ if (mscclppNcclDlopenSharedLib) return true ;
156+ if (!mscclpp::env ()->ncclSharedLibPath .empty ()) {
157+ if (mscclppNcclDlopenInit () == dlopenSuccess) {
158+ mscclppNcclDlopenSharedLib = true ;
159+ return true ;
160+ }
161+ }
162+ return false ;
163+ }
164+
138165// Declare the global map to store associations between raw pointer and shared pointer
139166static std::unordered_map<void *, std::shared_ptr<char >> ptrMap;
140167
@@ -261,15 +288,6 @@ static mscclpp::Algorithm algoSelector(
261288 mscclpp::isCuMemMapAllocated (const_cast <void *>(input)) && mscclpp::isCuMemMapAllocated (output);
262289 bool mscclppDisableChannelCache = mscclpp::env ()->disableChannelCache ;
263290 bool useNvlsWithZeroCopy = mscclpp::isNvlsSupported () && !mscclppDisableChannelCache && isCuMemMapAllocated;
264- if (collective == " broadcast" ) {
265- #if defined(__HIP_PLATFORM_AMD__)
266- return algoMapByCollective.at (collective).at (" default_broadcast6" );
267- #else
268- if (!mscclppNcclDlopenSharedLib) {
269- return algoMapByCollective.at (collective).at (" default_broadcast6" );
270- }
271- #endif
272- }
273291 if (collective == " allgather" ) {
274292 if (messageSize <= 32 * (1 << 20 )) {
275293 return algoMapByCollective.at (collective).at (" default_allgather6" );
@@ -358,11 +376,8 @@ NCCL_API ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueI
358376#endif
359377
360378 const std::string ncclLibPath = mscclpp::env ()->ncclSharedLibPath ;
361- if (!ncclLibPath.empty ()) {
362- int dlopenStatus = mscclppNcclDlopenInit ();
363- if (dlopenStatus == dlopenSuccess) {
364- mscclppNcclDlopenSharedLib = true ;
365- } else {
379+ if (!ncclLibPath.empty () && !mscclppNcclDlopenSharedLib) {
380+ if (!tryLoadNcclSharedLib ()) {
366381 WARN (" Failed to load the shared library for nccl/rccl" );
367382 return ncclInternalError;
368383 }
@@ -559,9 +574,13 @@ NCCL_API ncclResult_t ncclRedOpDestroy(ncclRedOp_t, ncclComm_t) {
559574 return ncclInternalError;
560575}
561576
562- NCCL_API ncclResult_t ncclReduce (const void *, void *, size_t , ncclDataType_t, ncclRedOp_t, int , ncclComm_t ,
563- cudaStream_t) {
577+ NCCL_API ncclResult_t ncclReduce (const void * sendbuff , void * recvbuff , size_t count , ncclDataType_t datatype ,
578+ ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream ) {
564579 // TODO: implement this function
580+ if (mscclppNcclDlopenSharedLib == true ) {
581+ return mscclppNcclOps.Reduce (sendbuff, recvbuff, count, datatype, op, root,
582+ *reinterpret_cast <ncclComm_t*>(comm->mscclppNcclComm ), stream);
583+ }
565584 WARN (" ncclReduce is currently unavailable" );
566585 return ncclInternalError;
567586}
@@ -580,15 +599,14 @@ NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t
580599 }
581600 return ncclSuccess;
582601 }
583-
584- if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr ) {
602+ int rank = (comm != nullptr ) ? comm->comm ->bootstrap ()->getRank () : -1 ;
603+ if (comm == nullptr || ( sendbuff == nullptr && root == rank) || recvbuff == nullptr || bytes == 0 ) {
585604 WARN (
586605 " One or more of the following conditions is met: sendbuff or recvbuff pointer is nullptr, bytes is 0, "
587606 " or comm is nullptr." );
588607 return ncclInvalidArgument;
589608 }
590609
591- int rank = comm->comm ->bootstrap ()->getRank ();
592610 INFO (MSCCLPP_NCCL, " rank %d broadcast sendbuff %p recvbuff %p count %ld, dtype %d, comm: %p" , rank, sendbuff,
593611 recvbuff, count, datatype, comm);
594612
@@ -807,14 +825,22 @@ NCCL_API ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t
807825 return ncclInvalidUsage;
808826}
809827
810- NCCL_API ncclResult_t ncclSend (const void *, size_t , ncclDataType_t, int , ncclComm_t, cudaStream_t) {
811- // TODO: implement this function
828+ NCCL_API ncclResult_t ncclSend (const void * sendbuff, size_t count, ncclDataType_t datatype, int peer, ncclComm_t comm,
829+ cudaStream_t stream) {
830+ if (mscclppNcclDlopenSharedLib == true ) {
831+ return mscclppNcclOps.Send (sendbuff, count, datatype, peer, *reinterpret_cast <ncclComm_t*>(comm->mscclppNcclComm ),
832+ stream);
833+ }
812834 WARN (" ncclSend is currently unavailable" );
813835 return ncclInternalError;
814836}
815837
816- NCCL_API ncclResult_t ncclRecv (void *, size_t , ncclDataType_t, int , ncclComm_t, cudaStream_t) {
817- // TODO: implement this function
838+ NCCL_API ncclResult_t ncclRecv (void * recvbuff, size_t count, ncclDataType_t datatype, int peer, ncclComm_t comm,
839+ cudaStream_t stream) {
840+ if (mscclppNcclDlopenSharedLib == true ) {
841+ return mscclppNcclOps.Recv (recvbuff, count, datatype, peer, *reinterpret_cast <ncclComm_t*>(comm->mscclppNcclComm ),
842+ stream);
843+ }
818844 WARN (" ncclRecv is currently unavailable" );
819845 return ncclInternalError;
820846}
@@ -849,13 +875,21 @@ NCCL_API ncclResult_t ncclAllToAllv(const void* sendbuff, [[maybe_unused]] const
849875}
850876
851877NCCL_API ncclResult_t ncclGroupStart () {
852- // TODO: Do nothing for now
878+ if (tryLoadNcclSharedLib ()) {
879+ return mscclppNcclOps.GroupStart ();
880+ }
853885 WARN (" ncclGroupStart is currently unavailable, return success" );
854886 return ncclSuccess;
855887}
856888
857889NCCL_API ncclResult_t ncclGroupEnd () {
858- // TODO: Do nothing for now
890+ if (mscclppNcclDlopenSharedLib == true ) {
891+ return mscclppNcclOps.GroupEnd ();
892+ }
859893 WARN (" ncclGroupEnd is currently unavailable, return success" );
860894 return ncclSuccess;
861895}
0 commit comments