
Commit ed71b87

Bugfixes & perf improvements to all_to_all (#15)

1 parent: d480977 · commit: ed71b87

17 files changed: +248 −123 lines changed

csrc/all_to_all/all_to_all.cpp
Lines changed: 5 additions & 5 deletions

@@ -29,14 +29,14 @@ AllToAll::AllToAll(
       dpSize(dpSize),
       numSMs(get_sm_count()) {

-  ROSE_ASSERT(hiddenDimBytes % 16 == 0, "invalid hidden dim bytes");
-  ROSE_ASSERT(hiddenDimScaleBytes % 16 == 0, "invalid hidden dim scale bytes");
+  PPLX_ASSERT(hiddenDimBytes % 16 == 0, "invalid hidden dim bytes");
+  PPLX_ASSERT(hiddenDimScaleBytes % 16 == 0, "invalid hidden dim scale bytes");
   const size_t perTokenBytes =
       round_up<size_t>(hiddenDimBytes + hiddenDimScaleBytes + sizeof(uint32_t), 16);

-  ROSE_ASSERT(numLocalExperts != 0, "numLocalExperts is 0");
-  ROSE_ASSERT(numDPGroups > 1, "at least 2 DP groups are required");
-  ROSE_ASSERT(hiddenDimScaleBytes <= hiddenDimBytes, "invalid hidden dim bytes");
+  PPLX_ASSERT(numLocalExperts != 0, "numLocalExperts is 0");
+  PPLX_ASSERT(numDPGroups > 1, "at least 2 DP groups are required");
+  PPLX_ASSERT(hiddenDimScaleBytes <= hiddenDimBytes, "invalid hidden dim bytes");
 }

 AllToAll::~AllToAll() {}
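
Both alignment asserts and the perTokenBytes computation round up to a 16-byte boundary, which matches sizeof(int4), the granularity at which the CUDA kernels in this commit copy token data. A minimal sketch of helpers of this shape (an assumption: the repository's actual definitions in core/utils.h may differ):

    template <typename T> constexpr T ceil_div(T a, T b) { return (a + b - 1) / b; }
    template <typename T> constexpr T round_up(T a, T b) { return ceil_div(a, b) * b; }

    // Worked example with a hypothetical hiddenDimBytes = 14336, hiddenDimScaleBytes = 0:
    // round_up<size_t>(14336 + 0 + sizeof(uint32_t), 16) == round_up(14340, 16) == 14352.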

csrc/all_to_all/bench_all_to_all.cpp
Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
 #include <nvshmemx.h>
 #include <nvtx3/nvToolsExt.h>

+#include <array>
 #include <iomanip>
 #include <iostream>
csrc/all_to_all/internode.cpp
Lines changed: 8 additions & 8 deletions

@@ -36,36 +36,36 @@ AllToAllInterNode::AllToAllInterNode(
   numTokensPerDP = mallocZeroBuffer<uint32_t>(numLocalExperts * numDPGroups);

   numTokensBuffer = (uint64_t *)nvshmem_malloc(sizeof(uint64_t) * numLocalExperts * numDPGroups);
-  ROSE_ASSERT(numTokensBuffer != nullptr, "failed to allocate numTokensBuffer");
+  PPLX_ASSERT(numTokensBuffer != nullptr, "failed to allocate numTokensBuffer");
   cudaMemset(numTokensBuffer, 0, sizeof(uint64_t) * numLocalExperts * numDPGroups);

   numDispatchRecvBuffer =
       (uint64_t *)nvshmem_malloc(sizeof(uint64_t) * numLocalExperts * numDPGroups);
-  ROSE_ASSERT(numDispatchRecvBuffer != nullptr, "failed to allocate numDispatchRecvBuffer");
+  PPLX_ASSERT(numDispatchRecvBuffer != nullptr, "failed to allocate numDispatchRecvBuffer");
   cudaMemset(numDispatchRecvBuffer, 0, sizeof(uint64_t) * numLocalExperts * numDPGroups);

   combineSignalBuffer = (uint64_t *)nvshmem_malloc(sizeof(uint64_t) * maxNumTokens);
-  ROSE_ASSERT(combineSignalBuffer != nullptr, "failed to allocate combineSignalBuffer");
+  PPLX_ASSERT(combineSignalBuffer != nullptr, "failed to allocate combineSignalBuffer");
   cudaMemset(combineSignalBuffer, 0, sizeof(uint64_t) * maxNumTokens);

   combineSyncBuffer = (uint64_t *)nvshmem_malloc(sizeof(uint64_t) * worldSize);
-  ROSE_ASSERT(combineSyncBuffer != nullptr, "failed to allocate combineSyncBuffer");
+  PPLX_ASSERT(combineSyncBuffer != nullptr, "failed to allocate combineSyncBuffer");
   cudaMemset(combineSyncBuffer, 0, sizeof(uint64_t) * worldSize);

   // Buffers for dispatch.
   const size_t perTokenBytes =
       round_up<size_t>(hiddenDimBytes + hiddenDimScaleBytes + sizeof(uint32_t), 16);
   xDispatchIn = (std::byte *)nvshmem_malloc(maxNumTokens * perTokenBytes);
-  ROSE_ASSERT(xDispatchIn != nullptr, "failed to allocate xDispatchIn");
+  PPLX_ASSERT(xDispatchIn != nullptr, "failed to allocate xDispatchIn");
   xDispatchOut = (std::byte *)nvshmem_malloc(maxBatchTokens * perTokenBytes);
-  ROSE_ASSERT(xDispatchOut != nullptr, "failed to allocate xDispatchOut");
+  PPLX_ASSERT(xDispatchOut != nullptr, "failed to allocate xDispatchOut");

   // Buffers for combine. The allocations are a bit wider to accommodate all
   // possible data types (primarily float for testing and bfloat16 for prod).
   xCombineIn = (std::byte *)nvshmem_malloc(maxBatchTokens * hiddenDim * sizeof(float));
-  ROSE_ASSERT(xCombineIn != nullptr, "failed to allocate xCombineIn");
+  PPLX_ASSERT(xCombineIn != nullptr, "failed to allocate xCombineIn");
   xCombineOut = (std::byte *)nvshmem_malloc(maxNumTokens * numExperts * hiddenDim * sizeof(float));
-  ROSE_ASSERT(xCombineOut != nullptr, "failed to allocate xCombineOut");
+  PPLX_ASSERT(xCombineOut != nullptr, "failed to allocate xCombineOut");

   // Buffers for token tracking.
   sourceIndex = mallocZeroBuffer<uint32_t>(maxBatchTokens);
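
Each allocation above follows the same three-step pattern: nvshmem_malloc is a collective call that reserves space on the symmetric heap of every PE and returns nullptr on failure, hence the host-side assert before the buffer is zeroed. The pattern in isolation (a sketch using the same calls as the diff, with a placeholder count):

    // Collective symmetric-heap allocation, failure check, zero-initialization.
    uint64_t *buf = (uint64_t *)nvshmem_malloc(sizeof(uint64_t) * count);
    PPLX_ASSERT(buf != nullptr, "failed to allocate buffer");
    cudaMemset(buf, 0, sizeof(uint64_t) * count);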

csrc/all_to_all/internode_combine.cu
Lines changed: 1 addition & 1 deletion

@@ -218,7 +218,7 @@ void AllToAllInterNode::combine(
     ));
     break;
   default:
-    ROSE_UNREACHABLE("invalid split mode");
+    PPLX_UNREACHABLE("invalid split mode");
   }
   nvtxRangePop();
 }

csrc/all_to_all/internode_dispatch.cu
Lines changed: 7 additions & 8 deletions

@@ -4,7 +4,7 @@
 #include <nvtx3/nvToolsExt.h>

 #include "all_to_all/internode.h"
-#include "core/device_utils.h"
+#include "core/device_utils.cuh"
 #include "core/utils.h"

 using namespace pplx;
@@ -58,16 +58,15 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel(
   const unsigned dpGroup = rank / dpSize;
   const unsigned dpRank = rank % dpSize;
   const unsigned tokenDim = hiddenDim + hiddenDimScale;
-  const unsigned tokenStride =
-      device::round_up<unsigned>(tokenDim + sizeof(uint32_t), sizeof(int4));
+  const unsigned tokenStride = round_up<unsigned>(tokenDim + sizeof(uint32_t), sizeof(int4));
   const unsigned WARP_SIZE = 32;
   const unsigned warpId = threadIdx.x / WARP_SIZE;
   const unsigned laneId = threadIdx.x % WARP_SIZE;

   // Determine the number of tokens populated which are to be sent.
   const unsigned numSendTokens = boundM ? __ldg(boundM) : m;
-  ROSE_DEVICE_ASSERT(numSendTokens <= maxNumTokens);
-  ROSE_DEVICE_ASSERT(
+  PPLX_DEVICE_ASSERT(numSendTokens <= maxNumTokens);
+  PPLX_DEVICE_ASSERT(
       hiddenDimScale == 0 || numSendTokens == 0 || (expertXScale != nullptr && dpXScale != nullptr)
   );

@@ -170,14 +169,14 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel(
     }

     if (DO_RECV) {
-      __syncthreads();
+      cooperative_groups::this_grid().sync();
     }
   }

   if constexpr (DO_RECV) {
     // Wait for the token counts to be sent.
     const size_t numExpertsAndGroups = numLocalExperts * numDPGroups;
-    const size_t expertsPerBlock = device::ceil_div<size_t>(numExpertsAndGroups, gridDim.x);
+    const size_t expertsPerBlock = ceil_div<size_t>(numExpertsAndGroups, gridDim.x);
     uint32_t *sharedExpert = reinterpret_cast<uint32_t *>(sharedMemory);
     uint32_t *sharedToken = sharedExpert + expertsPerBlock;

@@ -353,7 +352,7 @@ void AllToAllInterNode::dispatch(
     ));
     break;
   default:
-    ROSE_UNREACHABLE("invalid split mode");
+    PPLX_UNREACHABLE("invalid split mode");
   }
   nvtxRangePop();
 }
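
The functional fix in this file swaps a per-block __syncthreads() for cooperative_groups::this_grid().sync(), so every block in the grid, not just every thread of one block, reaches the barrier before the receive phase reads counters written by other blocks. Grid-wide sync only works for kernels launched cooperatively. A self-contained sketch (hypothetical kernel and launch, not the repository's code, assuming a device that supports cooperative launch):

    #include <cooperative_groups.h>
    #include <cstdio>
    namespace cg = cooperative_groups;

    __global__ void gridSyncKernel(int *flag) {
      // One block publishes a value...
      if (blockIdx.x == 0 && threadIdx.x == 0) *flag = 1;
      // ...and the grid-wide barrier guarantees every block observes it afterwards.
      cg::this_grid().sync();
      if (threadIdx.x == 0 && *flag != 1) printf("block %d missed the write\n", blockIdx.x);
    }

    int main() {
      int *flag;
      cudaMalloc(&flag, sizeof(int));
      cudaMemset(flag, 0, sizeof(int));
      void *args[] = {&flag};
      // this_grid().sync() is only valid under a cooperative launch.
      cudaLaunchCooperativeKernel((void *)gridSyncKernel, dim3(4), dim3(32), args, 0, nullptr);
      cudaDeviceSynchronize();
      cudaFree(flag);
      return 0;
    }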

csrc/all_to_all/intranode.cpp
Lines changed: 29 additions & 1 deletion

@@ -76,7 +76,6 @@ AllToAllIntraNode::AllToAllIntraNode(
   }

   auto dstHandlesHost = distributed->allToAll(srcHandlesHost);
-
   for (unsigned i = 0; i < worldSize; i++) {
     auto &ptr = recvBuffers.emplace_back();
     if (i == rank) {
@@ -97,6 +96,31 @@ AllToAllIntraNode::AllToAllIntraNode(
     ));
   }

+  // Allocate the local buffer for dispatch counts.
+  CUDACHECK(cudaMalloc(&localRecvCountPtr, sizeof(uint32_t) * maxNumTokens));
+  CUDACHECK(cudaMemset(localRecvCountPtr, 0, sizeof(uint32_t) * maxNumTokens));
+  CUDACHECK(cudaMalloc(&countBuffersPtr, sizeof(uint32_t *) * worldSize));
+  {
+    cudaIpcMemHandle_t countHandle;
+    CUDACHECK(cudaIpcGetMemHandle(&countHandle, localRecvCountPtr));
+    auto countHandlesHost = distributed->allGather(countHandle);
+
+    countBuffers.resize(worldSize);
+    for (unsigned i = 0; i < worldSize; i++) {
+      if (i == rank) {
+        countBuffers[i] = localRecvCountPtr;
+      } else {
+        CUDACHECK(cudaIpcOpenMemHandle(
+            (void **)&countBuffers[i], countHandlesHost[i], cudaIpcMemLazyEnablePeerAccess
+        ));
+      }
+    }
+
+    CUDACHECK(cudaMemcpy(
+        countBuffersPtr, countBuffers.data(), sizeof(uint32_t *) * worldSize, cudaMemcpyHostToDevice
+    ));
+  }
+
   // Allocate the local buffers.
   tokenCount = mallocZeroBuffer<uint32_t>(numExperts);
   numTokensPerRank = mallocZeroBuffer<uint32_t>(numLocalExperts * worldSize);
@@ -117,11 +141,15 @@ AllToAllIntraNode::~AllToAllIntraNode() {
     CUDACHECK(cudaFree(sendBuffers[i]));
     if (i != rank) {
       CUDACHECK(cudaIpcCloseMemHandle(recvBuffers[i]));
+      CUDACHECK(cudaIpcCloseMemHandle(countBuffers[i]));
     }
   }

   CUDACHECK(cudaFree(recvBuffersPtr));
   CUDACHECK(cudaFree(sendBuffersPtr));
+  CUDACHECK(cudaFree(countBuffersPtr));
+  CUDACHECK(cudaFree(localRecvCountPtr));
+
   CUDACHECK(cudaFree(tokenCount));
   CUDACHECK(cudaFree(numTokensPerRank));
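
The new count-buffer setup reuses the CUDA IPC recipe already applied to the send/receive buffers: export a cudaIpcMemHandle_t for the local device allocation, exchange handles across ranks, and open every peer's handle. The core of that recipe in isolation (a sketch; distributed->allGather and CUDACHECK are the repository's own helpers, used as in the diff):

    cudaIpcMemHandle_t handle;
    CUDACHECK(cudaIpcGetMemHandle(&handle, localRecvCountPtr));
    auto handles = distributed->allGather(handle);  // one handle per rank

    countBuffers.resize(worldSize);
    for (unsigned i = 0; i < worldSize; i++) {
      if (i == rank) {
        countBuffers[i] = localRecvCountPtr;  // our own allocation needs no IPC mapping
      } else {
        // Map the peer's allocation into this process's address space.
        CUDACHECK(cudaIpcOpenMemHandle(
            (void **)&countBuffers[i], handles[i], cudaIpcMemLazyEnablePeerAccess));
      }
    }

Each cudaIpcOpenMemHandle must later be paired with cudaIpcCloseMemHandle, which is exactly what the destructor hunk adds.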

csrc/all_to_all/intranode.cuh
Lines changed: 2 additions & 2 deletions

@@ -1,6 +1,6 @@
 #pragma once

-#include "core/device_utils.h"
+#include "core/device_utils.cuh"

 #include <cstdint>

@@ -52,7 +52,7 @@ private:

   __device__ __forceinline__ std::byte *getBaseTokenPtr(unsigned rank) {
     return getBaseCounterPtr(rank) +
-           device::round_up<size_t>(numLocalExperts * sizeof(uint32_t), sizeof(int4));
+           round_up<size_t>(numLocalExperts * sizeof(uint32_t), sizeof(int4));
   }

 private:
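
The include now points at core/device_utils.cuh and the device:: qualifier is dropped here as well, suggesting round_up is now a single helper shared by host and device code. A sketch of that shape (an assumption about the header's contents, not its actual definition):

    // Callable from both host and device code when compiled by nvcc.
    template <typename T>
    __host__ __device__ constexpr T round_up(T value, T alignment) {
      return ((value + alignment - 1) / alignment) * alignment;
    }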

csrc/all_to_all/intranode.h
Lines changed: 14 additions & 10 deletions

@@ -66,22 +66,26 @@ class AllToAllIntraNode final : public AllToAll {
   /// @section Peer-to-Peer shared buffers.
   std::vector<std::byte *> sendBuffers;
   std::byte **sendBuffersPtr;
-
   std::vector<std::byte *> recvBuffers;
   std::byte **recvBuffersPtr;

+  /// Buffer to synchronize multiple senders with a receiver in dispatch.
+  uint32_t *localRecvCountPtr;
+  std::vector<uint32_t *> countBuffers;
+  uint32_t **countBuffersPtr;
+
   /// @section Global buffers for use within kernels.
-  uint32_t *numTokensPerRank = nullptr;
-  uint32_t *tokenCount = nullptr;
+  uint32_t *numTokensPerRank;
+  uint32_t *tokenCount;

   /// @section Internal buffers communicating between dispatch and combine.
-  uint32_t *sourceIndex = nullptr;
-  uint32_t *sourceExpert = nullptr;
-  uint32_t *sourceOffset = nullptr;
-  uint32_t *sourceRank = nullptr;
-  uint32_t *sourceToken = nullptr;
-  uint32_t *sourceRoute = nullptr;
-  uint32_t *tokenIndex = nullptr;
+  uint32_t *sourceIndex;
+  uint32_t *sourceExpert;
+  uint32_t *sourceOffset;
+  uint32_t *sourceRank;
+  uint32_t *sourceToken;
+  uint32_t *sourceRoute;
+  uint32_t *tokenIndex;
 };

 } // namespace pplx
