Skip to content
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
a91d273
WIP
Binyang2014 Aug 18, 2025
2fb4d74
WIP
Binyang2014 Aug 18, 2025
cc96414
WIP
Binyang2014 Aug 19, 2025
d561a5c
WIP
Binyang2014 Aug 19, 2025
0dfc13a
WIP
Binyang2014 Aug 19, 2025
11b3bdb
WIP
Binyang2014 Aug 19, 2025
a9e98ec
WIP
Binyang2014 Aug 20, 2025
f89ffe4
WIP
Binyang2014 Aug 21, 2025
82211f6
WIP
Binyang2014 Aug 21, 2025
978a182
WIP
Binyang2014 Aug 21, 2025
b89697d
WIP
Binyang2014 Aug 22, 2025
0f7ce20
Fix correctness
Binyang2014 Aug 22, 2025
e620110
WIP
Binyang2014 Aug 22, 2025
9df8ca3
clean code
Binyang2014 Aug 22, 2025
c922620
all works
Binyang2014 Aug 23, 2025
b5a8793
WIP
Binyang2014 Aug 23, 2025
f1a905f
Merge branch 'main' into binyli/nccl-algo
Binyang2014 Aug 23, 2025
1235a90
clean up
Binyang2014 Aug 23, 2025
030a4c2
WIP
Binyang2014 Aug 23, 2025
7bcd613
bug fix
Binyang2014 Aug 24, 2025
f1cde8b
fix compile error
Binyang2014 Aug 24, 2025
775f48e
WIP
Binyang2014 Aug 24, 2025
ca638b5
for amd
Binyang2014 Aug 24, 2025
646656b
WIP
Binyang2014 Aug 24, 2025
25f4ee2
WIP
Binyang2014 Aug 24, 2025
52211ce
WIP
Binyang2014 Aug 24, 2025
ffae384
add logs
Binyang2014 Aug 24, 2025
7be42b9
merge main
Binyang2014 Aug 25, 2025
f5b6f48
Merge branch 'main' into binyli/nccl-algo
Binyang2014 Aug 25, 2025
2313197
remove nccl.h from algorithm file
Binyang2014 Aug 26, 2025
1d6aee0
move algo to core lib
Binyang2014 Aug 26, 2025
3d3842b
WIP
Binyang2014 Aug 27, 2025
8e57b6b
WIP
Binyang2014 Aug 28, 2025
bea70b0
WIP
Binyang2014 Aug 28, 2025
b025bd3
WIP
Binyang2014 Aug 29, 2025
cd28260
WIP
Binyang2014 Aug 29, 2025
e2d33fb
example works
Binyang2014 Aug 29, 2025
d2941f1
add doc
Binyang2014 Aug 29, 2025
d1e636a
WIP
Binyang2014 Aug 29, 2025
33c6eb5
WIP
Binyang2014 Sep 2, 2025
44196ec
update doc
Binyang2014 Sep 2, 2025
0cbe6d7
update
Binyang2014 Sep 3, 2025
04e53c8
refactor
Binyang2014 Sep 9, 2025
4f91a50
merge main
Binyang2014 Sep 9, 2025
6583661
WIP
Binyang2014 Sep 9, 2025
351b5b0
WIP
Binyang2014 Sep 9, 2025
bc468b9
Merge branch 'main' into binyli/nccl-algo
Binyang2014 Sep 15, 2025
fde6120
update
Binyang2014 Sep 15, 2025
fb7405b
Merge branch 'main' into binyli/nccl-algo
Binyang2014 Sep 22, 2025
11b1364
merge main
Binyang2014 Sep 25, 2025
242ab3a
Address comments
Binyang2014 Sep 25, 2025
ae8a9d0
WIP
Binyang2014 Sep 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions apps/nccl/include/mscclpp/nccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#define NCCL_H_

#include <mscclpp/gpu.hpp>
#include <mscclpp/gpu_data_types.hpp>

#ifdef __cplusplus
extern "C" {
Expand Down Expand Up @@ -260,6 +261,38 @@ typedef enum {
#endif
} ncclDataType_t;

/// Returns the size in bytes of a single element of the given NCCL data type.
/// Sentinel (ncclNumTypes) and unrecognized values yield 0. Cases are grouped
/// by element width; FP8/BF16 entries exist only when the toolkit provides them.
static inline size_t ncclTypeSize(ncclDataType_t type) {
  switch (type) {
    // 1-byte element types
    case ncclInt8:
    case ncclUint8:
#if defined(__CUDA_FP8_TYPES_EXIST__)
    case ncclFp8E4M3:
    case ncclFp8E5M2:
#endif  // defined(__CUDA_FP8_TYPES_EXIST__)
      return 1;
    // 2-byte element types
    case ncclFloat16:
#if defined(__CUDA_BF16_TYPES_EXIST__)
    case ncclBfloat16:
#endif  // defined(__CUDA_BF16_TYPES_EXIST__)
      return 2;
    // 4-byte element types
    case ncclInt32:
    case ncclUint32:
    case ncclFloat32:
      return 4;
    // 8-byte element types
    case ncclInt64:
    case ncclUint64:
    case ncclFloat64:
      return 8;
    case ncclNumTypes:
      return 0;
  }
  return 0;
}

/* ncclScalarResidence_t: Location and dereferencing logic for scalar arguments. */
typedef enum {
/* ncclScalarDevice: The scalar is in device-visible memory and will be
Expand Down
218 changes: 218 additions & 0 deletions apps/nccl/src/allgather.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <mscclpp/nccl.h>

#include <atomic>

#include <mscclpp/algorithm.hpp>
#include <mscclpp/env.hpp>
#include <mscclpp/gpu_utils.hpp>

#include "allgather.hpp"
#include "debug.h"

// Channel-cache behavior is fixed at construction from the MSCCLPP
// environment; when disabled, every call sets up fresh channels.
AllgatherAlgo6::AllgatherAlgo6()
    : disableChannelCache_(static_cast<bool>(mscclpp::env()->disableChannelCache)) {}

// One-time setup: establish the peer connections reused by every context this
// builder later creates. The extras map is unused by this algorithm.
void AllgatherAlgo6::initialize(std::shared_ptr<mscclpp::Communicator> comm,
                                std::unordered_map<std::string, std::shared_ptr<void>>&) {
  conns_ = setupConnections(comm);
}

/// Launches the allgather6 kernel on `stream`.
/// `count` is the per-rank element count; the block count is chosen
/// heuristically from the message size. Returns ncclInternalError if the
/// launch fails, ncclSuccess otherwise.
ncclResult_t AllgatherAlgo6::allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input,
                                                 void* output, size_t count, ncclDataType_t dtype, cudaStream_t stream,
                                                 std::unordered_map<std::string, std::shared_ptr<void>>&) {
  const size_t bytes = count * ncclTypeSize(dtype);
  const size_t nElem = bytes / sizeof(int);
  const int rank = ctx->rank;

  // Size-based launch heuristic (tuned block counts).
  int nBlocks = 28;
  if (bytes <= 32 * (1 << 20)) {
    if (nElem <= 4096) {
      nBlocks = 7;
    } else if (nElem <= 32768) {
      nBlocks = 14;
    } else if (nElem >= 2097152) {
      nBlocks = 35;
    }
  } else {
    nBlocks = 35;
  }

  // Use at() rather than operator[]: a missing key is a context-setup bug and
  // should throw, whereas operator[] would default-insert a null shared_ptr
  // and the dereference below would be undefined behavior.
  size_t channelOutOffset = *static_cast<size_t*>(ctx->extras.at("channel_out_offset").get());
  // In-place when the input aliases this rank's slot of the output buffer.
  if ((char*)input == (char*)output + rank * bytes) {
    allgather6<false><<<nBlocks, 1024, 0, stream>>>((void*)input, ctx->memoryChannelDeviceHandles.get(),
                                                    channelOutOffset, ctx->rank, ctx->workSize, ctx->nRanksPerNode,
                                                    nElem);
  } else {
    allgather6<true><<<nBlocks, 1024, 0, stream>>>((void*)input, ctx->memoryChannelDeviceHandles.get(),
                                                   channelOutOffset, ctx->rank, ctx->workSize, ctx->nRanksPerNode,
                                                   nElem);
  }
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    WARN("AllgatherAlgo6 failed with error %d", err);
    return ncclInternalError;
  }
  return ncclSuccess;
}

/// Builds a per-buffer context: semaphores, registered memory, and memory
/// channels for the output allocation. With channel caching enabled the whole
/// allocation containing `output` is registered so the context can be reused
/// for any buffer inside it; with caching disabled, exactly the user buffer
/// is registered.
std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo6::initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm,
                                                                            const void*, void* output, size_t count,
                                                                            ncclDataType_t dtype) {
  constexpr int nChannelsPerConnection = 35;

  auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
  ctx->rank = comm->bootstrap()->getRank();
  ctx->workSize = comm->bootstrap()->getNranks();
  ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();

  // setup semaphores (no std::move: the call already yields a prvalue, so
  // moving it is redundant)
  ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);

  // Locate the allocation that contains `output` and the buffer's offset in it.
  size_t recvBytes;
  CUdeviceptr recvBasePtr;
  MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)output));
  size_t channelOutOffset = (char*)output - (char*)recvBasePtr;
  if (disableChannelCache_) {
    // Caching disabled: register exactly the user's buffer, no offset.
    channelOutOffset = 0;
    recvBytes = count * ncclTypeSize(dtype);
    recvBasePtr = (CUdeviceptr)output;
  }
  ctx->extras.insert({"channel_out_offset", std::make_shared<size_t>(channelOutOffset)});

  // register the memory for the allgather operation
  mscclpp::RegisteredMemory localMemory =
      comm->registerMemory((void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
  std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(comm, ctx->rank, localMemory);
  ctx->memoryChannels =
      setupMemoryChannels(this->conns_, ctx->memorySemaphores, remoteMemories, localMemory, nChannelsPerConnection);
  ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);

  // keep registered memories alive for the lifetime of the context
  ctx->registeredMemories = std::move(remoteMemories);
  ctx->registeredMemories.push_back(localMemory);

  return ctx;
}

/// Produces the cache key for a context. With the channel cache disabled every
/// call returns a unique key (forcing fresh setup); otherwise the key is the
/// base/size of the allocation containing `output`, so reuse of the same
/// allocation hits the cache.
mscclpp::AlgorithmCtxKey AllgatherAlgo6::generateAllgatherContextKey(const void*, void* output, size_t,
                                                                     ncclDataType_t) {
  // Atomic so concurrent callers never hand out the same tag; a plain
  // `static int` with `tag++` would be a data race.
  static std::atomic<int> tag{0};
  if (disableChannelCache_) {
    // always return a new key if channel cache is disabled
    return mscclpp::AlgorithmCtxKey{nullptr, nullptr, 0, 0, tag.fetch_add(1, std::memory_order_relaxed)};
  }
  size_t recvBytes;
  CUdeviceptr recvBasePtr;
  MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)output));
  return mscclpp::AlgorithmCtxKey{nullptr, (void*)recvBasePtr, 0, recvBytes, 0};
}

/// Wires a shared AllgatherAlgo6 instance into an Algorithm object. Each
/// callback captures `self` by value so the builder outlives the Algorithm.
/// The int dtype from the generic interface is narrowed back to ncclDataType_t.
mscclpp::Algorithm AllgatherAlgo6::build() {
  auto self = std::make_shared<AllgatherAlgo6>();
  auto initFn = [self](std::shared_ptr<mscclpp::Communicator> comm,
                       std::unordered_map<std::string, std::shared_ptr<void>>& extras) {
    self->initialize(comm, extras);
  };
  auto kernelFn = [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output,
                         size_t count, int dtype, cudaStream_t stream,
                         std::unordered_map<std::string, std::shared_ptr<void>>& extras) {
    return self->allgatherKernelFunc(ctx, input, output, count, static_cast<ncclDataType_t>(dtype), stream, extras);
  };
  auto ctxFn = [self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t count,
                      int dtype) {
    return self->initAllgatherContext(comm, input, output, count, static_cast<ncclDataType_t>(dtype));
  };
  auto keyFn = [self](const void* input, void* output, size_t count, int dtype) {
    return self->generateAllgatherContextKey(input, output, count, static_cast<ncclDataType_t>(dtype));
  };
  return mscclpp::Algorithm("default_allgather6", "allgather", initFn, kernelFn, ctxFn, keyFn);
}

/// One-time setup: establish peer connections and adopt the shared scratch
/// buffer handed in through `extras` ("scratch" pointer + "scratch_size").
/// at() throws on a missing key, making mis-wired extras fail fast.
void AllgatherAlgo8::initialize(std::shared_ptr<mscclpp::Communicator> comm,
                                std::unordered_map<std::string, std::shared_ptr<void>>& extras) {
  this->conns_ = setupConnections(comm);
  this->scratchBuffer_ = std::static_pointer_cast<char>(extras.at("scratch"));
  // static_cast instead of the original C-style cast, per file convention.
  this->scratchBufferSize_ = *static_cast<size_t*>(extras.at("scratch_size").get());
}

/// Launches the allgather8 kernel (fixed 56 blocks x 1024 threads) on
/// `stream`, staging through the scratch buffer. Returns ncclInternalError on
/// launch failure, ncclSuccess otherwise.
ncclResult_t AllgatherAlgo8::allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input,
                                                 void* output, size_t count, ncclDataType_t dtype, cudaStream_t stream,
                                                 std::unordered_map<std::string, std::shared_ptr<void>>&) {
  const int rank = ctx->rank;
  const size_t bytes = count * ncclTypeSize(dtype);
  const size_t nElem = bytes / sizeof(int);
  // In-place when the input aliases this rank's slot of the output buffer.
  const bool inPlace = (char*)input == (char*)output + rank * bytes;
  if (inPlace) {
    allgather8<false><<<56, 1024, 0, stream>>>((void*)input, scratchBuffer_.get(), (void*)output,
                                               ctx->memoryChannelDeviceHandles.get(), rank, ctx->nRanksPerNode,
                                               ctx->workSize, nElem);
  } else {
    allgather8<true><<<56, 1024, 0, stream>>>((void*)input, scratchBuffer_.get(), (void*)output,
                                              ctx->memoryChannelDeviceHandles.get(), rank, ctx->nRanksPerNode,
                                              ctx->workSize, nElem);
  }
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    WARN("AllgatherAlgo8 failed with error %d", err);
    return ncclInternalError;
  }
  return ncclSuccess;
}

/// Builds a context for the scratch-staged allgather: registers the input
/// buffer locally, exposes the scratch buffer to peers, and sets up memory
/// channels. The output pointer is unused — results are assembled from the
/// scratch buffer by the kernel.
std::shared_ptr<mscclpp::AlgorithmCtx> AllgatherAlgo8::initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm,
                                                                            const void* input, void*, size_t count,
                                                                            ncclDataType_t dtype) {
  constexpr int nChannelsPerConnection = 56;

  auto ctx = std::make_shared<mscclpp::AlgorithmCtx>();
  ctx->rank = comm->bootstrap()->getRank();
  ctx->workSize = comm->bootstrap()->getNranks();
  ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode();

  // setup semaphores (no std::move: the call already yields a prvalue, so
  // moving it is redundant)
  ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, nChannelsPerConnection);

  const size_t bytes = count * ncclTypeSize(dtype);
  // register the memory for the allgather operation
  mscclpp::RegisteredMemory localMemory = comm->registerMemory((void*)input, bytes, mscclpp::Transport::CudaIpc);
  mscclpp::RegisteredMemory scratchMemory =
      comm->registerMemory(this->scratchBuffer_.get(), scratchBufferSize_, mscclpp::Transport::CudaIpc);
  std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(comm, ctx->rank, scratchMemory);

  // setup channels
  ctx->memoryChannels =
      setupMemoryChannels(this->conns_, ctx->memorySemaphores, remoteMemories, localMemory, nChannelsPerConnection);
  ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels);

  // keep registered memories alive for the lifetime of the context
  ctx->registeredMemories = std::move(remoteMemories);
  ctx->registeredMemories.push_back(localMemory);
  ctx->registeredMemories.push_back(scratchMemory);

  return ctx;
}

// This algorithm is non-zero-copy: contexts are independent of the caller's
// buffers, so every call maps to one shared cached context via a fixed key.
mscclpp::AlgorithmCtxKey AllgatherAlgo8::generateAllgatherContextKey(const void*, void*, size_t, ncclDataType_t) {
  mscclpp::AlgorithmCtxKey sharedKey{nullptr, nullptr, 0, 0, 0};
  return sharedKey;
}

/// Wires a shared AllgatherAlgo8 instance into an Algorithm object. Each
/// callback captures `self` by value so the builder outlives the Algorithm.
/// The int dtype from the generic interface is narrowed back to ncclDataType_t.
mscclpp::Algorithm AllgatherAlgo8::build() {
  auto self = std::make_shared<AllgatherAlgo8>();
  auto initFn = [self](std::shared_ptr<mscclpp::Communicator> comm,
                       std::unordered_map<std::string, std::shared_ptr<void>>& extras) {
    self->initialize(comm, extras);
  };
  auto kernelFn = [self](const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output,
                         size_t count, int dtype, cudaStream_t stream,
                         std::unordered_map<std::string, std::shared_ptr<void>>& extras) {
    return self->allgatherKernelFunc(ctx, input, output, count, static_cast<ncclDataType_t>(dtype), stream, extras);
  };
  auto ctxFn = [self](std::shared_ptr<mscclpp::Communicator> comm, const void* input, void* output, size_t count,
                      int dtype) {
    return self->initAllgatherContext(comm, input, output, count, static_cast<ncclDataType_t>(dtype));
  };
  auto keyFn = [self](const void* input, void* output, size_t count, int dtype) {
    return self->generateAllgatherContextKey(input, output, count, static_cast<ncclDataType_t>(dtype));
  };
  return mscclpp::Algorithm("default_allgather8", "allgather", initFn, kernelFn, ctxFn, keyFn);
}
71 changes: 42 additions & 29 deletions apps/nccl/src/allgather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
#ifndef ALLGATHER_HPP_
#define ALLGATHER_HPP_

#include <mscclpp/nccl.h>

#include <mscclpp/algorithm.hpp>
#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/gpu.hpp>
Expand Down Expand Up @@ -206,34 +209,44 @@ __global__ void __launch_bounds__(1024, 1)
}
}

template <bool IsOutOfPlace, typename T>
cudaError_t allgather(T* buff, [[maybe_unused]] T* scratch, [[maybe_unused]] T* resultBuff,
mscclpp::DeviceHandle<mscclpp::MemoryChannel>* memoryChannels, size_t channelOutOffset, int rank,
int nRanksPerNode, [[maybe_unused]] int worldSize, size_t nelems, cudaStream_t stream) {
int nBlocks = 28;
if (nelems * sizeof(T) <= 32 * (1 << 20)) {
if (nelems <= 4096) {
nBlocks = 7;
} else if (nelems <= 32768) {
nBlocks = 14;
} else if (nelems >= 2097152) {
nBlocks = 35;
}
allgather6<IsOutOfPlace><<<nBlocks, 1024, 0, stream>>>((void*)buff, memoryChannels, channelOutOffset, rank,
worldSize, nRanksPerNode, nelems * sizeof(T) / sizeof(int));
} else {
#if defined(__HIP_PLATFORM_AMD__)
nBlocks = 35;
allgather6<IsOutOfPlace><<<nBlocks, 1024, 0, stream>>>((void*)buff, memoryChannels, channelOutOffset, rank,
worldSize, nRanksPerNode, nelems * sizeof(T) / sizeof(int));
#else
nBlocks = 56;
allgather8<IsOutOfPlace><<<nBlocks, 1024, 0, stream>>>((void*)buff, (void*)scratch, (void*)resultBuff,
memoryChannels, rank, nRanksPerNode, worldSize,
nelems * sizeof(T) / sizeof(int));
#endif
}
return cudaGetLastError();
}
/// Builder for the "default_allgather6" algorithm. Contexts register the
/// allocation containing the output buffer and are keyed on it, so repeated
/// calls on the same allocation reuse cached channels unless caching is
/// disabled via the environment.
class AllgatherAlgo6 : public mscclpp::AlgorithmBuilder {
 public:
  AllgatherAlgo6();
  /// Assembles the Algorithm object from this builder's private callbacks.
  mscclpp::Algorithm build() override;

 private:
  // Set from mscclpp::env()->disableChannelCache at construction; when true,
  // every call produces a fresh context key (no channel reuse).
  bool disableChannelCache_;
  // Peer connections established once in initialize() and reused by contexts.
  std::vector<std::shared_ptr<mscclpp::Connection>> conns_;

  // Establishes peer connections (extras unused by this algorithm).
  void initialize(std::shared_ptr<mscclpp::Communicator> comm, std::unordered_map<std::string, std::shared_ptr<void>>&);
  // Launches the allgather6 kernel; count is the per-rank element count.
  ncclResult_t allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output,
                                   size_t count, [[maybe_unused]] ncclDataType_t dtype, cudaStream_t stream,
                                   std::unordered_map<std::string, std::shared_ptr<void>>& extras);

  // Builds a context (semaphores, registered memory, channels) for `output`.
  std::shared_ptr<mscclpp::AlgorithmCtx> initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm, const void*,
                                                              void* output, size_t, ncclDataType_t);
  // Cache key: output allocation base/size, or a unique key when caching is off.
  mscclpp::AlgorithmCtxKey generateAllgatherContextKey(const void*, void*, size_t, ncclDataType_t);
};

/// Builder for the "default_allgather8" algorithm: a non-zero-copy variant
/// that stages data through a shared scratch buffer supplied via the extras
/// map at initialize() time. All calls share one cached context.
class AllgatherAlgo8 : public mscclpp::AlgorithmBuilder {
 public:
  /// Assembles the Algorithm object from this builder's private callbacks.
  mscclpp::Algorithm build() override;

 private:
  // Peer connections established once in initialize() and reused by contexts.
  std::vector<std::shared_ptr<mscclpp::Connection>> conns_;

  // Establishes peer connections and adopts the scratch buffer from extras
  // ("scratch" pointer and "scratch_size" byte count).
  void initialize(std::shared_ptr<mscclpp::Communicator> comm,
                  std::unordered_map<std::string, std::shared_ptr<void>>& extras);
  // Launches the allgather8 kernel; count is the per-rank element count.
  ncclResult_t allgatherKernelFunc(const std::shared_ptr<mscclpp::AlgorithmCtx> ctx, const void* input, void* output,
                                   size_t count, [[maybe_unused]] ncclDataType_t dtype, cudaStream_t stream,
                                   std::unordered_map<std::string, std::shared_ptr<void>>& extras);

  // Builds a context that registers the input buffer and the scratch buffer.
  std::shared_ptr<mscclpp::AlgorithmCtx> initAllgatherContext(std::shared_ptr<mscclpp::Communicator> comm, const void*,
                                                              void* output, size_t, ncclDataType_t);
  // Always returns the same key: contexts are buffer-independent.
  mscclpp::AlgorithmCtxKey generateAllgatherContextKey(const void*, void*, size_t, ncclDataType_t);

  size_t scratchBufferSize_;
  std::shared_ptr<char> scratchBuffer_;
};

#endif // ALLGATHER_HPP_
Loading
Loading