1+ // Copyright (c) Microsoft Corporation.
2+ // Licensed under the MIT license.
3+
#include <mscclpp/nccl.h>

#include <mscclpp/algorithm.hpp>
#include <mscclpp/env.hpp>
#include <mscclpp/gpu_utils.hpp>

#include "allgather.hpp"
#include "debug.h"
12+
13+ AllgatherAlgo6::AllgatherAlgo6 () : disableChannelCache_(false ) {
14+ if (mscclpp::env ()->disableChannelCache ) {
15+ disableChannelCache_ = true ;
16+ }
17+ }
18+
// One-time setup for allgather6: establish connections to the peer ranks.
// The extras map is unused here but required by the algorithm-init signature.
void AllgatherAlgo6::initialize(std::shared_ptr<mscclpp::Communicator> comm,
                                std::unordered_map<std::string, std::shared_ptr<void>>&) {
  this->conns_ = setupConnections(comm);
}
23+
24+ ncclResult_t AllgatherAlgo6::allgatherKernelFunc (const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void * input,
25+ void * output, size_t count, ncclDataType_t dtype, cudaStream_t stream,
26+ std::unordered_map<std::string, std::shared_ptr<void >>&) {
27+ int nBlocks = 28 ;
28+ const size_t bytes = count * ncclTypeSize (dtype);
29+ const size_t nElem = bytes / sizeof (int );
30+ int rank = ctx->rank ;
31+ if (bytes <= 32 * (1 << 20 )) {
32+ if (nElem <= 4096 ) {
33+ nBlocks = 7 ;
34+ } else if (nElem <= 32768 ) {
35+ nBlocks = 14 ;
36+ } else if (nElem >= 2097152 ) {
37+ nBlocks = 35 ;
38+ }
39+ } else {
40+ nBlocks = 35 ;
41+ }
42+
43+ size_t channelOutOffset = *static_cast <size_t *>(ctx->extras [" channel_out_offset" ].get ());
44+ if ((char *)input == (char *)output + rank * bytes) {
45+ allgather6<false ><<<nBlocks, 1024 , 0 , stream>>> ((void *)input, ctx->memoryChannelDeviceHandles .get (),
46+ channelOutOffset, ctx->rank , ctx->workSize , ctx->nRanksPerNode ,
47+ nElem);
48+ } else {
49+ allgather6<true ><<<nBlocks, 1024 , 0 , stream>>> ((void *)input, ctx->memoryChannelDeviceHandles .get (),
50+ channelOutOffset, ctx->rank , ctx->workSize , ctx->nRanksPerNode ,
51+ nElem);
52+ }
53+ cudaError_t err = cudaGetLastError ();
54+ if (err != cudaSuccess) {
55+ WARN (" AllgatherAlgo6 failed with error %d" , err);
56+ return ncclInternalError;
57+ }
58+ return ncclSuccess;
59+ }
60+
61+ std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo6::initAllgatherContext (std::shared_ptr<mscclpp::Communicator> comm,
62+ const void *, void * output, size_t count,
63+ ncclDataType_t dtype) {
64+ constexpr int nChannelsPerConnection = 35 ;
65+
66+ auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
67+ ctx->rank = comm->bootstrap ()->getRank ();
68+ ctx->workSize = comm->bootstrap ()->getNranks ();
69+ ctx->nRanksPerNode = comm->bootstrap ()->getNranksPerNode ();
70+
71+ // setup semaphores
72+ ctx->memorySemaphores = std::move (setupMemorySemaphores (comm, this ->conns_ , nChannelsPerConnection));
73+
74+ size_t bytes = count * ncclTypeSize (dtype);
75+ size_t recvBytes;
76+ CUdeviceptr recvBasePtr;
77+ MSCCLPP_CUTHROW (cuMemGetAddressRange (&recvBasePtr, &recvBytes, (CUdeviceptr)output));
78+ size_t channelOutOffset = (char *)output - (char *)recvBasePtr;
79+ if (disableChannelCache_) {
80+ channelOutOffset = 0 ;
81+ recvBytes = bytes;
82+ recvBasePtr = (CUdeviceptr)output;
83+ }
84+ ctx->extras .insert ({" channel_out_offset" , std::make_shared<size_t >(channelOutOffset)});
85+
86+ // register the memory for the broadcast operation
87+ mscclpp::RegisteredMemory localMemory =
88+ comm->registerMemory ((void *)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
89+ std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories (comm, ctx->rank , localMemory);
90+ ctx->memoryChannels = std::move (
91+ setupMemoryChannels (this ->conns_ , ctx->memorySemaphores , remoteMemories, localMemory, nChannelsPerConnection));
92+ ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles (ctx->memoryChannels );
93+
94+ // keep registered memories reference
95+ ctx->registeredMemories = std::move (remoteMemories);
96+ ctx->registeredMemories .push_back (localMemory);
97+
98+ return ctx;
99+ }
100+
101+ mscclpp::AlgorithmCtxKey AllgatherAlgo6::generateAllgatherContextKey (const void *, void * output, size_t ,
102+ ncclDataType_t) {
103+ static int tag = 0 ;
104+ if (disableChannelCache_) {
105+ // always return a new key if channel cache is disabled
106+ return mscclpp::AlgorithmCtxKey{nullptr , nullptr , 0 , 0 , tag++};
107+ }
108+ size_t recvBytes;
109+ CUdeviceptr recvBasePtr;
110+ MSCCLPP_CUTHROW (cuMemGetAddressRange (&recvBasePtr, &recvBytes, (CUdeviceptr)output));
111+ return mscclpp::AlgorithmCtxKey{nullptr , (void *)recvBasePtr, 0 , recvBytes, 0 };
112+ }
113+
114+ mscclpp::Algorithm AllgatherAlgo6::build () {
115+ auto self = std::make_shared<AllgatherAlgo6>();
116+ mscclpp::Algorithm allgatherAlgo (
117+ " default_allgather6" , " allgather" ,
118+ [self](std::shared_ptr<mscclpp::Communicator> comm,
119+ std::unordered_map<std::string, std::shared_ptr<void >>& extras) { self->initialize (comm, extras); },
120+ [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void * input, void * output, size_t count, int dtype,
121+ cudaStream_t stream, std::unordered_map<std::string, std::shared_ptr<void >>& extras) {
122+ return self->allgatherKernelFunc (ctx, input, output, count, static_cast <ncclDataType_t>(dtype), stream, extras);
123+ },
124+ [self](std::shared_ptr<mscclpp::Communicator> comm, const void * input, void * output, size_t count, int dtype) {
125+ return self->initAllgatherContext (comm, input, output, count, static_cast <ncclDataType_t>(dtype));
126+ },
127+ [self](const void * input, void * output, size_t count, int dtype) {
128+ return self->generateAllgatherContextKey (input, output, count, static_cast <ncclDataType_t>(dtype));
129+ });
130+ return allgatherAlgo;
131+ }
132+
133+ void AllgatherAlgo8::initialize (std::shared_ptr<mscclpp::Communicator> comm,
134+ std::unordered_map<std::string, std::shared_ptr<void >>& extras) {
135+ this ->conns_ = setupConnections (comm);
136+ this ->scratchBuffer_ = std::static_pointer_cast<char >(extras.at (" scratch" ));
137+ this ->scratchBufferSize_ = *(size_t *)(extras.at (" scratch_size" ).get ());
138+ }
139+
140+ ncclResult_t AllgatherAlgo8::allgatherKernelFunc (const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void * input,
141+ void * output, size_t count, ncclDataType_t dtype, cudaStream_t stream,
142+ std::unordered_map<std::string, std::shared_ptr<void >>&) {
143+ int rank = ctx->rank ;
144+ const size_t bytes = count * ncclTypeSize (dtype);
145+ const size_t nElem = bytes / sizeof (int );
146+ if ((char *)input == (char *)output + rank * bytes) {
147+ allgather8<false ><<<56 , 1024 , 0 , stream>>> ((void *)input, this ->scratchBuffer_ .get (), (void *)output,
148+ ctx->memoryChannelDeviceHandles .get (), rank, ctx->nRanksPerNode ,
149+ ctx->workSize , nElem);
150+ } else {
151+ allgather8<true ><<<56 , 1024 , 0 , stream>>> ((void *)input, this ->scratchBuffer_ .get (), (void *)output,
152+ ctx->memoryChannelDeviceHandles .get (), rank, ctx->nRanksPerNode ,
153+ ctx->workSize , nElem);
154+ }
155+ cudaError_t err = cudaGetLastError ();
156+ if (err != cudaSuccess) {
157+ WARN (" AllgatherAlgo8 failed with error %d" , err);
158+ return ncclInternalError;
159+ }
160+ return ncclSuccess;
161+ }
162+
163+ std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo8::initAllgatherContext (std::shared_ptr<mscclpp::Communicator> comm,
164+ const void * input, void *, size_t count,
165+ ncclDataType_t dtype) {
166+ constexpr int nChannelsPerConnection = 56 ;
167+
168+ auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
169+ ctx->rank = comm->bootstrap ()->getRank ();
170+ ctx->workSize = comm->bootstrap ()->getNranks ();
171+ ctx->nRanksPerNode = comm->bootstrap ()->getNranksPerNode ();
172+
173+ // setup semaphores
174+ ctx->memorySemaphores = std::move (setupMemorySemaphores (comm, this ->conns_ , nChannelsPerConnection));
175+
176+ size_t bytes = count * ncclTypeSize (dtype);
177+ // register the memory for the broadcast operation
178+ mscclpp::RegisteredMemory localMemory = comm->registerMemory ((void *)input, bytes, mscclpp::Transport::CudaIpc);
179+ mscclpp::RegisteredMemory scratchMemory =
180+ comm->registerMemory (this ->scratchBuffer_ .get (), scratchBufferSize_, mscclpp::Transport::CudaIpc);
181+ std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories (comm, ctx->rank , scratchMemory);
182+
183+ // setup channels
184+ ctx->memoryChannels = std::move (
185+ setupMemoryChannels (this ->conns_ , ctx->memorySemaphores , remoteMemories, localMemory, nChannelsPerConnection));
186+ ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles (ctx->memoryChannels );
187+
188+ // keep registered memories reference
189+ ctx->registeredMemories = std::move (remoteMemories);
190+ ctx->registeredMemories .push_back (localMemory);
191+ ctx->registeredMemories .push_back (scratchMemory);
192+
193+ return ctx;
194+ }
195+
196+ mscclpp::AlgorithmCtxKey AllgatherAlgo8::generateAllgatherContextKey (const void *, void *, size_t , ncclDataType_t) {
197+ // always return same key, non-zero copy algo
198+ return mscclpp::AlgorithmCtxKey{nullptr , nullptr , 0 , 0 , 0 };
199+ }
200+
201+ mscclpp::Algorithm AllgatherAlgo8::build () {
202+ auto self = std::make_shared<AllgatherAlgo8>();
203+ mscclpp::Algorithm allgatherAlgo (
204+ " default_allgather8" , " allgather" ,
205+ [self](std::shared_ptr<mscclpp::Communicator> comm,
206+ std::unordered_map<std::string, std::shared_ptr<void >>& extras) { self->initialize (comm, extras); },
207+ [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void * input, void * output, size_t count, int dtype,
208+ cudaStream_t stream, std::unordered_map<std::string, std::shared_ptr<void >>& extras) {
209+ return self->allgatherKernelFunc (ctx, input, output, count, static_cast <ncclDataType_t>(dtype), stream, extras);
210+ },
211+ [self](std::shared_ptr<mscclpp::Communicator> comm, const void * input, void * output, size_t count, int dtype) {
212+ return self->initAllgatherContext (comm, input, output, count, static_cast <ncclDataType_t>(dtype));
213+ },
214+ [self](const void * input, void * output, size_t count, int dtype) {
215+ return self->generateAllgatherContextKey (input, output, count, static_cast <ncclDataType_t>(dtype));
216+ });
217+ return allgatherAlgo;
218+ }