Skip to content

Commit 70b8297

Browse files
authored
Revise NCCL API implementation (#617)
- Make the NCCL interface extensible: customers can register their own algorithms with the NCCL API, and users can provide a customized algorithm-selection function. - Fall back to NCCL/RCCL if no algorithm is selected by the selection function. - The MSCCLPP interfaces now work at any scale.
1 parent 5ac4276 commit 70b8297

File tree

20 files changed

+2117
-801
lines changed

20 files changed

+2117
-801
lines changed

apps/nccl/include/mscclpp/nccl.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#define NCCL_H_
77

88
#include <mscclpp/gpu.hpp>
9+
#include <mscclpp/gpu_data_types.hpp>
910

1011
#ifdef __cplusplus
1112
extern "C" {
@@ -260,6 +261,38 @@ typedef enum {
260261
#endif
261262
} ncclDataType_t;
262263

264+
/// Returns the size in bytes of a single element of the given NCCL data type.
/// Returns 0 for ncclNumTypes or any value outside the enumeration.
static inline size_t ncclTypeSize(ncclDataType_t type) {
  switch (type) {
    // 1-byte types
    case ncclInt8:
    case ncclUint8:
#if defined(__CUDA_FP8_TYPES_EXIST__)
    case ncclFp8E4M3:
    case ncclFp8E5M2:
#endif  // defined(__CUDA_FP8_TYPES_EXIST__)
      return 1;
    // 2-byte types
    case ncclFloat16:
#if defined(__CUDA_BF16_TYPES_EXIST__)
    case ncclBfloat16:
#endif  // defined(__CUDA_BF16_TYPES_EXIST__)
      return 2;
    // 4-byte types
    case ncclInt32:
    case ncclUint32:
    case ncclFloat32:
      return 4;
    // 8-byte types
    case ncclInt64:
    case ncclUint64:
    case ncclFloat64:
      return 8;
    case ncclNumTypes:
      return 0;
  }
  return 0;
}
295+
263296
/* ncclScalarResidence_t: Location and dereferencing logic for scalar arguments. */
264297
typedef enum {
265298
/* ncclScalarDevice: The scalar is in device-visible memory and will be

apps/nccl/src/allgather.cu

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT license.
3+
4+
#include <mscclpp/nccl.h>
5+
6+
#include <mscclpp/algorithm.hpp>
7+
#include <mscclpp/env.hpp>
8+
#include <mscclpp/gpu_utils.hpp>
9+
10+
#include "allgather.hpp"
11+
#include "debug.h"
12+
13+
// Reads the channel-cache setting from the environment once at construction
// time; when set, every call builds fresh channels instead of reusing cached
// ones.
AllgatherAlgo6::AllgatherAlgo6()
    : disableChannelCache_(static_cast<bool>(mscclpp::env()->disableChannelCache)) {}
18+
19+
void AllgatherAlgo6::initialize(std::shared_ptr<mscclpp::Communicator> comm,
20+
std::unordered_map<std::string, std::shared_ptr<void>>&) {
21+
this->conns_ = setupConnections(comm);
22+
}
23+
24+
ncclResult_t AllgatherAlgo6::allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input,
25+
void* output, size_t count, ncclDataType_t dtype, cudaStream_t stream,
26+
std::unordered_map<std::string, std::shared_ptr<void>>&) {
27+
int nBlocks = 28;
28+
const size_t bytes = count * ncclTypeSize(dtype);
29+
const size_t nElem = bytes / sizeof(int);
30+
int rank = ctx->rank;
31+
if (bytes <= 32 * (1 << 20)) {
32+
if (nElem <= 4096) {
33+
nBlocks = 7;
34+
} else if (nElem <= 32768) {
35+
nBlocks = 14;
36+
} else if (nElem >= 2097152) {
37+
nBlocks = 35;
38+
}
39+
} else {
40+
nBlocks = 35;
41+
}
42+
43+
size_t channelOutOffset = *static_cast<size_t*>(ctx->extras["channel_out_offset"].get());
44+
if ((char*)input == (char*)output + rank * bytes) {
45+
allgather6<false><<<nBlocks, 1024, 0, stream>>>((void*)input, ctx->memoryChannelDeviceHandles.get(),
46+
channelOutOffset, ctx->rank, ctx->workSize, ctx->nRanksPerNode,
47+
nElem);
48+
} else {
49+
allgather6<true><<<nBlocks, 1024, 0, stream>>>((void*)input, ctx->memoryChannelDeviceHandles.get(),
50+
channelOutOffset, ctx->rank, ctx->workSize, ctx->nRanksPerNode,
51+
nElem);
52+
}
53+
cudaError_t err = cudaGetLastError();
54+
if (err != cudaSuccess) {
55+
WARN("AllgatherAlgo6 failed with error %d", err);
56+
return ncclInternalError;
57+
}
58+
return ncclSuccess;
59+
}
60+
61+
// Builds a reusable context for allgather6: rank/topology info, semaphores,
// memory channels over the registered output allocation, and the byte offset
// of `output` inside that allocation (stored in extras).
std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo6::initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm,
                                                                            const void*, void* output, size_t count,
                                                                            ncclDataType_t dtype) {
  constexpr int nChannelsPerConnection = 35;

  auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
  ctx->rank = comm->bootstrap()->getRank();
  ctx->workSize = comm->bootstrap()->getNranks();
  ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();

  // Setup semaphores. No std::move here: the call already yields an rvalue,
  // and wrapping it in std::move only blocks copy elision.
  ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);

  size_t bytes = count * ncclTypeSize(dtype);
  // Channels normally cover the whole allocation that `output` lives in, so a
  // cached context can be reused for any buffer within the same allocation.
  size_t recvBytes;
  CUdeviceptr recvBasePtr;
  MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)output));
  size_t channelOutOffset = (char*)output - (char*)recvBasePtr;
  if (disableChannelCache_) {
    // Caching disabled: register exactly the user buffer, no offset.
    channelOutOffset = 0;
    recvBytes = bytes;
    recvBasePtr = (CUdeviceptr)output;
  }
  ctx->extras.insert({"channel_out_offset", std::make_shared<size_t>(channelOutOffset)});

  // Register the memory and exchange handles with the peers.
  mscclpp::RegisteredMemory localMemory =
      comm->registerMemory((void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
  std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(comm, ctx->rank, localMemory);
  ctx->memoryChannels =
      setupMemoryChannels(this->conns_, ctx->memorySemaphores, remoteMemories, localMemory, nChannelsPerConnection);
  ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);

  // Keep the registered memories alive for the lifetime of the context.
  ctx->registeredMemories = std::move(remoteMemories);
  ctx->registeredMemories.push_back(localMemory);

  return ctx;
}
100+
101+
// Derives the cache key for an allgather6 context. Contexts are cached per
// output allocation; with the channel cache disabled, a monotonically
// increasing tag forces a cache miss on every call.
// NOTE(review): `tag` is a non-atomic function-local static — presumably this
// is only reached from a single thread; confirm before relying on it.
mscclpp::AlgorithmCtxKey AllgatherAlgo6::generateAllgatherContextKey(const void*, void* output, size_t,
                                                                     ncclDataType_t) {
  static int tag = 0;
  if (disableChannelCache_) {
    return mscclpp::AlgorithmCtxKey{nullptr, nullptr, 0, 0, tag++};
  }
  CUdeviceptr recvBasePtr;
  size_t recvBytes;
  MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)output));
  return mscclpp::AlgorithmCtxKey{nullptr, (void*)recvBasePtr, 0, recvBytes, 0};
}
113+
114+
// Wires the member functions of a shared AllgatherAlgo6 instance into a
// generic mscclpp::Algorithm (init / kernel / context-init / context-key).
mscclpp::Algorithm AllgatherAlgo6::build() {
  auto impl = std::make_shared<AllgatherAlgo6>();
  return mscclpp::Algorithm(
      "default_allgather6", "allgather",
      [impl](std::shared_ptr<mscclpp::Communicator> comm,
             std::unordered_map<std::string, std::shared_ptr<void>>& extras) { impl->initialize(comm, extras); },
      [impl](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* sendbuff, void* recvbuff, size_t count,
             int dtype, cudaStream_t stream, std::unordered_map<std::string, std::shared_ptr<void>>& extras) {
        return impl->allgatherKernelFunc(ctx, sendbuff, recvbuff, count, static_cast<ncclDataType_t>(dtype), stream,
                                         extras);
      },
      [impl](std::shared_ptr<mscclpp::Communicator> comm, const void* sendbuff, void* recvbuff, size_t count,
             int dtype) {
        return impl->initAllgatherContext(comm, sendbuff, recvbuff, count, static_cast<ncclDataType_t>(dtype));
      },
      [impl](const void* sendbuff, void* recvbuff, size_t count, int dtype) {
        return impl->generateAllgatherContextKey(sendbuff, recvbuff, count, static_cast<ncclDataType_t>(dtype));
      });
}
132+
133+
// One-time setup: establish peer connections and adopt the shared scratch
// buffer the caller provides through `extras` ("scratch" / "scratch_size").
// at() throws if either key is missing, which surfaces a wiring bug early.
void AllgatherAlgo8::initialize(std::shared_ptr<mscclpp::Communicator> comm,
                                std::unordered_map<std::string, std::shared_ptr<void>>& extras) {
  this->conns_ = setupConnections(comm);
  this->scratchBuffer_ = std::static_pointer_cast<char>(extras.at("scratch"));
  // static_cast instead of the original C-style cast — same behavior, greppable intent.
  this->scratchBufferSize_ = *static_cast<size_t*>(extras.at("scratch_size").get());
}
139+
140+
// Launches the scratch-buffer-based allgather8 kernel with a fixed geometry
// of 56 blocks x 1024 threads, choosing the in-place or out-of-place variant.
// Returns ncclInternalError if the kernel launch fails, ncclSuccess otherwise.
ncclResult_t AllgatherAlgo8::allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input,
                                                 void* output, size_t count, ncclDataType_t dtype, cudaStream_t stream,
                                                 std::unordered_map<std::string, std::shared_ptr<void>>&) {
  constexpr int kNumBlocks = 56;
  constexpr int kThreadsPerBlock = 1024;
  const size_t bytes = count * ncclTypeSize(dtype);
  const size_t nElem = bytes / sizeof(int);
  const int rank = ctx->rank;
  // In-place when the input aliases this rank's slot inside the output buffer.
  if ((char*)input == (char*)output + rank * bytes) {
    allgather8<false><<<kNumBlocks, kThreadsPerBlock, 0, stream>>>(
        (void*)input, this->scratchBuffer_.get(), (void*)output, ctx->memoryChannelDeviceHandles.get(), rank,
        ctx->nRanksPerNode, ctx->workSize, nElem);
  } else {
    allgather8<true><<<kNumBlocks, kThreadsPerBlock, 0, stream>>>(
        (void*)input, this->scratchBuffer_.get(), (void*)output, ctx->memoryChannelDeviceHandles.get(), rank,
        ctx->nRanksPerNode, ctx->workSize, nElem);
  }
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    WARN("AllgatherAlgo8 failed with error %d", err);
    return ncclInternalError;
  }
  return ncclSuccess;
}
162+
163+
// Builds a context for allgather8: registers the local input buffer and the
// shared scratch buffer, then creates memory channels targeting the peers'
// scratch buffers.
std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo8::initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm,
                                                                            const void* input, void*, size_t count,
                                                                            ncclDataType_t dtype) {
  constexpr int nChannelsPerConnection = 56;

  auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
  ctx->rank = comm->bootstrap()->getRank();
  ctx->workSize = comm->bootstrap()->getNranks();
  ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();

  // Setup semaphores. No std::move here: the call already yields an rvalue,
  // and wrapping it in std::move only blocks copy elision.
  ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);

  size_t bytes = count * ncclTypeSize(dtype);
  // Register the local input and the scratch staging buffer.
  mscclpp::RegisteredMemory localMemory = comm->registerMemory((void*)input, bytes, mscclpp::Transport::CudaIpc);
  mscclpp::RegisteredMemory scratchMemory =
      comm->registerMemory(this->scratchBuffer_.get(), scratchBufferSize_, mscclpp::Transport::CudaIpc);
  std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(comm, ctx->rank, scratchMemory);

  // Setup channels.
  ctx->memoryChannels =
      setupMemoryChannels(this->conns_, ctx->memorySemaphores, remoteMemories, localMemory, nChannelsPerConnection);
  ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);

  // Keep the registered memories alive for the lifetime of the context.
  ctx->registeredMemories = std::move(remoteMemories);
  ctx->registeredMemories.push_back(localMemory);
  ctx->registeredMemories.push_back(scratchMemory);

  return ctx;
}
195+
196+
// allgather8 stages all data through the shared scratch buffer, so a single
// cached context serves every call: always return the same (all-zero) key.
mscclpp::AlgorithmCtxKey AllgatherAlgo8::generateAllgatherContextKey(const void*, void*, size_t, ncclDataType_t) {
  mscclpp::AlgorithmCtxKey key{nullptr, nullptr, 0, 0, 0};
  return key;
}
200+
201+
// Wires the member functions of a shared AllgatherAlgo8 instance into a
// generic mscclpp::Algorithm (init / kernel / context-init / context-key).
mscclpp::Algorithm AllgatherAlgo8::build() {
  auto impl = std::make_shared<AllgatherAlgo8>();
  return mscclpp::Algorithm(
      "default_allgather8", "allgather",
      [impl](std::shared_ptr<mscclpp::Communicator> comm,
             std::unordered_map<std::string, std::shared_ptr<void>>& extras) { impl->initialize(comm, extras); },
      [impl](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* sendbuff, void* recvbuff, size_t count,
             int dtype, cudaStream_t stream, std::unordered_map<std::string, std::shared_ptr<void>>& extras) {
        return impl->allgatherKernelFunc(ctx, sendbuff, recvbuff, count, static_cast<ncclDataType_t>(dtype), stream,
                                         extras);
      },
      [impl](std::shared_ptr<mscclpp::Communicator> comm, const void* sendbuff, void* recvbuff, size_t count,
             int dtype) {
        return impl->initAllgatherContext(comm, sendbuff, recvbuff, count, static_cast<ncclDataType_t>(dtype));
      },
      [impl](const void* sendbuff, void* recvbuff, size_t count, int dtype) {
        return impl->generateAllgatherContextKey(sendbuff, recvbuff, count, static_cast<ncclDataType_t>(dtype));
      });
}

apps/nccl/src/allgather.hpp

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
#ifndef ALLGATHER_HPP_
55
#define ALLGATHER_HPP_
66

7+
#include <mscclpp/nccl.h>
8+
9+
#include <mscclpp/algorithm.hpp>
710
#include <mscclpp/concurrency_device.hpp>
811
#include <mscclpp/core.hpp>
912
#include <mscclpp/gpu.hpp>
@@ -206,34 +209,44 @@ __global__ void __launch_bounds__(1024, 1)
206209
}
207210
}
208211

209-
template <bool IsOutOfPlace, typename T>
210-
cudaError_t allgather(T* buff, [[maybe_unused]] T* scratch, [[maybe_unused]] T* resultBuff,
211-
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels, size_t channelOutOffset, int rank,
212-
int nRanksPerNode, [[maybe_unused]] int worldSize, size_t nelems, cudaStream_t stream) {
213-
int nBlocks = 28;
214-
if (nelems * sizeof(T) <= 32 * (1 << 20)) {
215-
if (nelems <= 4096) {
216-
nBlocks = 7;
217-
} else if (nelems <= 32768) {
218-
nBlocks = 14;
219-
} else if (nelems >= 2097152) {
220-
nBlocks = 35;
221-
}
222-
allgather6<IsOutOfPlace><<<nBlocks, 1024, 0, stream>>>((void*)buff, memoryChannels, channelOutOffset, rank,
223-
worldSize, nRanksPerNode, nelems * sizeof(T) / sizeof(int));
224-
} else {
225-
#if defined(__HIP_PLATFORM_AMD__)
226-
nBlocks = 35;
227-
allgather6<IsOutOfPlace><<<nBlocks, 1024, 0, stream>>>((void*)buff, memoryChannels, channelOutOffset, rank,
228-
worldSize, nRanksPerNode, nelems * sizeof(T) / sizeof(int));
229-
#else
230-
nBlocks = 56;
231-
allgather8<IsOutOfPlace><<<nBlocks, 1024, 0, stream>>>((void*)buff, (void*)scratch, (void*)resultBuff,
232-
memoryChannels, rank, nRanksPerNode, worldSize,
233-
nelems * sizeof(T) / sizeof(int));
234-
#endif
235-
}
236-
return cudaGetLastError();
237-
}
212+
/// Builder for the channel-cached allgather algorithm (kernel allgather6):
/// peers' data is written directly into the registered output allocation.
/// Contexts are cached per output allocation unless the channel cache is
/// disabled via the environment.
class AllgatherAlgo6 : public mscclpp::AlgorithmBuilder {
 public:
  AllgatherAlgo6();
  mscclpp::Algorithm build() override;

 private:
  bool disableChannelCache_;  // set from mscclpp::env()->disableChannelCache in the ctor
  std::vector<std::shared_ptr<mscclpp::Connection>> conns_;

  // One-time connection setup for the communicator.
  void initialize(std::shared_ptr<mscclpp::Communicator> comm, std::unordered_map<std::string, std::shared_ptr<void>>&);
  // Launches the allgather6 kernel; returns an NCCL status code.
  ncclResult_t allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output,
                                   size_t count, [[maybe_unused]] ncclDataType_t dtype, cudaStream_t stream,
                                   std::unordered_map<std::string, std::shared_ptr<void>>& extras);

  // Builds semaphores/channels for the output allocation.
  std::shared_ptr<mscclpp::AlgorithmCtx> initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm, const void*,
                                                              void* output, size_t, ncclDataType_t);
  // Cache key: keyed by the output's base allocation, or unique per call when caching is disabled.
  mscclpp::AlgorithmCtxKey generateAllgatherContextKey(const void*, void*, size_t, ncclDataType_t);
};
230+
231+
/// Builder for the scratch-buffer (non-zero-copy) allgather algorithm (kernel
/// allgather8): data is staged through a pre-allocated scratch buffer that the
/// caller provides via the extras map in initialize().
class AllgatherAlgo8 : public mscclpp::AlgorithmBuilder {
 public:
  mscclpp::Algorithm build() override;

 private:
  std::vector<std::shared_ptr<mscclpp::Connection>> conns_;

  // One-time setup: connections plus scratch buffer from extras ("scratch", "scratch_size").
  void initialize(std::shared_ptr<mscclpp::Communicator> comm,
                  std::unordered_map<std::string, std::shared_ptr<void>>& extras);
  // Launches the allgather8 kernel; returns an NCCL status code.
  ncclResult_t allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output,
                                   size_t count, [[maybe_unused]] ncclDataType_t dtype, cudaStream_t stream,
                                   std::unordered_map<std::string, std::shared_ptr<void>>& extras);

  // Registers input + scratch memory and builds channels to the peers' scratch buffers.
  std::shared_ptr<mscclpp::AlgorithmCtx> initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm, const void*,
                                                              void* output, size_t, ncclDataType_t);
  // Always returns the same key: one cached context serves all calls.
  mscclpp::AlgorithmCtxKey generateAllgatherContextKey(const void*, void*, size_t, ncclDataType_t);

  size_t scratchBufferSize_;             // from extras["scratch_size"]
  std::shared_ptr<char> scratchBuffer_;  // from extras["scratch"]
};
238251

239252
#endif // ALLGATHER_HPP_

0 commit comments

Comments
 (0)