@@ -199,6 +199,7 @@ struct ncclComm {
199
199
std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
200
200
std::unordered_map<channelKey, NvlsChannelInfo> channelNvlsInfos;
201
201
std::shared_ptr<char > scratchBuff;
202
+ mscclpp::RegisteredMemory registeredScratchMemory;
202
203
std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
203
204
std::vector<ChannelInfo> channelInfos;
204
205
@@ -268,30 +269,29 @@ static Op getReduceOp(ncclRedOp_t op) {
268
269
}
269
270
270
271
static std::vector<mscclpp::RegisteredMemory> setupRemoteMemories (std::shared_ptr<mscclpp::Communicator> comm, int rank,
271
- void * buff, size_t bytes,
272
- mscclpp::TransportFlags transport) {
272
+ mscclpp::RegisteredMemory localMemory) {
273
273
std::vector<mscclpp::RegisteredMemory> remoteMemories;
274
- mscclpp::RegisteredMemory memory = comm->registerMemory (buff, bytes, transport);
275
274
std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
276
275
for (int i = 0 ; i < comm->bootstrap ()->getNranks (); i++) {
277
276
if (i == rank) continue ;
278
277
remoteRegMemoryFutures.push_back (comm->recvMemory (i));
279
- comm->sendMemory (memory , i);
278
+ comm->sendMemory (localMemory , i);
280
279
}
281
280
std::transform (remoteRegMemoryFutures.begin (), remoteRegMemoryFutures.end (), std::back_inserter (remoteMemories),
282
281
[](const auto & future) { return future.get (); });
283
282
return remoteMemories;
284
283
}
285
284
286
285
static std::vector<mscclpp::MemoryChannel> setupMemoryChannels (
287
- ncclComm_t comm, const std::vector<mscclpp::RegisteredMemory>& remoteMemories, void * src) {
286
+ ncclComm_t comm, const std::vector<mscclpp::RegisteredMemory>& remoteMemories,
287
+ mscclpp::RegisteredMemory localMemory) {
288
288
std::vector<mscclpp::MemoryChannel> channels;
289
289
std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memorySemaphores = comm->memorySemaphores ;
290
290
size_t nConnections = comm->connections .size ();
291
291
for (size_t idx = 0 ; idx < NUM_CHANNELS_PER_CONNECTION; ++idx) {
292
292
for (size_t cid = 0 ; cid < nConnections; ++cid) {
293
293
if (comm->connections [cid]->transport () == mscclpp::Transport::CudaIpc) {
294
- channels.emplace_back (memorySemaphores[idx * nConnections + cid], remoteMemories[cid], src , nullptr );
294
+ channels.emplace_back (memorySemaphores[idx * nConnections + cid], remoteMemories[cid], localMemory , nullptr );
295
295
}
296
296
}
297
297
}
@@ -432,8 +432,10 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
432
432
if (count * ncclTypeSize (datatype) <= (1 << 20 ) || mscclpp::isNvlsSupported ()) {
433
433
auto sendIt = comm->channelScratchInfos .find (sendKey);
434
434
if (sendIt == comm->channelScratchInfos .end ()) {
435
+ mscclpp::RegisteredMemory localMemory =
436
+ comm->comm ->registerMemory ((void *)sendBasePtr, sendBytes, mscclpp::Transport::CudaIpc);
435
437
std::vector<mscclpp::MemoryChannel> channels =
436
- setupMemoryChannels (comm, comm->remoteScratchRegMemories , const_cast < void *>(( void *)sendBasePtr) );
438
+ setupMemoryChannels (comm, comm->remoteScratchRegMemories , localMemory );
437
439
ChannelInfo channelInfo{channels, setupMemoryChannelDeviceHandles (channels)};
438
440
sendIt = comm->channelScratchInfos .emplace (sendKey, channelInfo).first ;
439
441
}
@@ -444,8 +446,10 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
444
446
445
447
auto sendIt = comm->channelInInfos .find (sendKey);
446
448
if (sendIt == comm->channelInInfos .end ()) {
449
+ mscclpp::RegisteredMemory localMemory =
450
+ comm->comm ->registerMemory ((void *)sendBasePtr, sendBytes, mscclpp::Transport::CudaIpc);
447
451
std::vector<mscclpp::MemoryChannel> channels =
448
- setupMemoryChannels (comm, comm->remoteScratchRegMemories , const_cast < void *>(( void *)sendBasePtr) );
452
+ setupMemoryChannels (comm, comm->remoteScratchRegMemories , localMemory );
449
453
ChannelInfo channelInfo{channels, setupMemoryChannelDeviceHandles (channels)};
450
454
sendIt = comm->channelInInfos .emplace (sendKey, channelInfo).first ;
451
455
}
@@ -457,10 +461,10 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
457
461
recvBasePtr = (CUdeviceptr)recvbuff;
458
462
offsetOut = 0 ;
459
463
}
460
- remoteMemories =
461
- setupRemoteMemories ( comm->comm , rank, (void *)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
462
- std::vector<mscclpp::MemoryChannel> outChannels =
463
- setupMemoryChannels (comm, remoteMemories, const_cast < void *>(( void *)recvBasePtr) );
464
+ mscclpp::RegisteredMemory localMemory =
465
+ comm->comm -> registerMemory ( (void *)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
466
+ remoteMemories = setupRemoteMemories (comm-> comm , rank, localMemory);
467
+ std::vector<mscclpp::MemoryChannel> outChannels = setupMemoryChannels (comm, remoteMemories, localMemory );
464
468
ChannelInfo channelInfo{outChannels, setupMemoryChannelDeviceHandles (outChannels)};
465
469
recvIt = comm->channelOutInfos .emplace (recvKey, channelInfo).first ;
466
470
if (mscclppDisableChannelCache == true ) {
@@ -552,10 +556,10 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff,
552
556
recvBasePtr = (CUdeviceptr)recvbuff;
553
557
offsetOut = 0 ;
554
558
}
555
- std::vector< mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories (
556
- comm->comm , rank, const_cast < void *> ((void *)recvBasePtr) , recvBytes, mscclpp::Transport::CudaIpc);
557
- std::vector<mscclpp::MemoryChannel> channels =
558
- setupMemoryChannels (comm, remoteMemories, const_cast < void *>(( void *)recvBasePtr) );
559
+ mscclpp::RegisteredMemory localMemory =
560
+ comm->comm -> registerMemory ((void *)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
561
+ std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories (comm-> comm , rank, localMemory);
562
+ std::vector<mscclpp::MemoryChannel> channels = setupMemoryChannels (comm, remoteMemories, localMemory );
559
563
std::vector<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> memoryChannelDeviceHandles;
560
564
std::transform (channels.begin (), channels.end (), std::back_inserter (memoryChannelDeviceHandles),
561
565
[](const mscclpp::MemoryChannel& memoryChannel) { return mscclpp::deviceHandle (memoryChannel); });
@@ -577,8 +581,10 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff,
577
581
#else
578
582
auto sendIt = comm->channelInInfos .find (sendKey);
579
583
if (sendIt == comm->channelInInfos .end ()) {
584
+ mscclpp::RegisteredMemory localMemory =
585
+ comm->comm ->registerMemory ((void *)sendBasePtr, sendBytes, mscclpp::Transport::CudaIpc);
580
586
std::vector<mscclpp::MemoryChannel> channels =
581
- setupMemoryChannels (comm, comm->remoteScratchRegMemories , const_cast < void *>(( void *)sendBasePtr) );
587
+ setupMemoryChannels (comm, comm->remoteScratchRegMemories , localMemory );
582
588
ChannelInfo channelInfo{channels, setupMemoryChannelDeviceHandles (channels)};
583
589
sendIt = comm->channelInInfos .emplace (sendKey, channelInfo).first ;
584
590
}
@@ -629,8 +635,9 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt
629
635
commPtr->buffFlag = 0 ;
630
636
commPtr->numScratchBuff = 2 ;
631
637
commPtr->scratchBuff = mscclpp::GpuBuffer<char >(SCRATCH_SIZE).memory ();
632
- commPtr->remoteScratchRegMemories =
633
- setupRemoteMemories (commPtr->comm , rank, commPtr->scratchBuff .get (), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
638
+ commPtr->registeredScratchMemory =
639
+ commPtr->comm ->registerMemory (commPtr->scratchBuff .get (), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
640
+ commPtr->remoteScratchRegMemories = setupRemoteMemories (commPtr->comm , rank, commPtr->registeredScratchMemory );
634
641
635
642
commPtr->deviceFlag7 = mscclpp::detail::gpuCallocShared<uint32_t >(7 );
636
643
commPtr->deviceFlag28 = mscclpp::detail::gpuCallocShared<uint32_t >(28 );
@@ -935,12 +942,10 @@ NCCL_API ncclResult_t ncclBroadcastFallback(const void* sendbuff, void* recvbuff
935
942
936
943
auto it = comm->channelOutInfos .find (recvKey);
937
944
if (it == comm->channelOutInfos .end ()) {
938
- // std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(
939
- // comm->comm, rank, const_cast<void*>((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc);
940
- // std::vector<mscclpp::MemoryChannel> channels =
941
- // setupMemoryChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
945
+ mscclpp::RegisteredMemory localMemory =
946
+ comm->comm ->registerMemory ((void *)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
942
947
std::vector<mscclpp::MemoryChannel> channels =
943
- setupMemoryChannels (comm, comm->remoteScratchRegMemories , const_cast < void *>(( void *)recvBasePtr) );
948
+ setupMemoryChannels (comm, comm->remoteScratchRegMemories , localMemory );
944
949
std::vector<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> memoryChannelDeviceHandles;
945
950
std::transform (channels.begin (), channels.end (), std::back_inserter (memoryChannelDeviceHandles),
946
951
[](const mscclpp::MemoryChannel& memoryChannel) { return mscclpp::deviceHandle (memoryChannel); });
0 commit comments