Skip to content

Commit 514656b

Browse files
committed
rename scatter/gather to dispatch/combine
1 parent 7af3059 commit 514656b

18 files changed

+201
-148
lines changed

.gitignore

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,11 @@
11
build-cmake
22
build
3-
pplx_kernels/*.so
3+
*.so
44
*.egg-info
55
*.pyc
66
data
7+
dist
8+
.ruff_cache
9+
.mypy_cache
10+
.pytest_cache
11+
__pycache__

csrc/all_to_all/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,8 +2,8 @@
22

33
add_library(all_to_all_lib STATIC
44
all_to_all.cpp
5-
internode_scatter.cu
6-
internode_gather.cu
5+
internode_dispatch.cu
6+
internode_combine.cu
77
internode.cpp
88
)
99
target_link_libraries(all_to_all_lib PUBLIC

csrc/all_to_all/all_to_all.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ class AllToAll {
6666
/// The maximum number of tokens in a batch.
6767
const size_t maxBatchTokens;
6868

69-
/// @section Internal buffers communicating between scatter and gather.
69+
/// @section Internal buffers communicating between dispatch and combine.
7070
uint32_t *numTokensPerDP = nullptr;
7171
uint32_t *sourceIndex = nullptr;
7272
uint32_t *sourceExpert = nullptr;

csrc/all_to_all/bench_all_to_all.cpp

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// All-to-all scatter benchmark
1+
// All-to-all benchmark
22

33
#include "all_to_all/internode.h"
44
#include "all_to_all/test_utils.h"
@@ -100,14 +100,14 @@ std::pair<Time, Time> benchmark(
100100
// Warmup
101101
auto run = [&]() -> std::pair<float, float> {
102102
nvshmemx_barrier_all_on_stream(stream);
103-
// Scatter.
103+
// Dispatch.
104104
for (size_t i = 0; i < numSamples; i++) {
105105
nvshmemx_barrier_all_on_stream(stream);
106106
CUDACHECK(cudaStreamSynchronize(stream));
107107

108108
CUDACHECK(cudaEventRecord(std::get<0>(events[i]), stream));
109109

110-
allToAll.scatter(
110+
allToAll.dispatch(
111111
Strided1D<int32_t>(outTokensPerExpertDevice, 1),
112112
Strided2D<std::byte>(
113113
outExpertDevice, hiddenDimBytes, hiddenDimBytes * config.numTokens * numPEs
@@ -128,7 +128,7 @@ std::pair<Time, Time> benchmark(
128128

129129
CUDACHECK(cudaEventRecord(std::get<1>(events[i]), stream));
130130

131-
allToAll.gather<T>(
131+
allToAll.combine<T>(
132132
Strided1D<nv_bfloat16>(outTokensDevice, config.hiddenDim),
133133
Strided2D<uint32_t>(indicesDevice, 1, config.expertsPerToken),
134134
Strided2D<float>(weightsDevice, 1, config.expertsPerToken),
@@ -145,15 +145,15 @@ std::pair<Time, Time> benchmark(
145145
}
146146

147147
CUDACHECK(cudaStreamSynchronize(stream));
148-
float totalScatterMs = 0.0f, totalGatherMs = 0.0f;
148+
float totalDispatchMs = 0.0f, totalCombineMs = 0.0f;
149149
for (size_t i = 0; i < numSamples; i++) {
150-
float scatterMs = 0.0f, gatherMs = 0.0f;
151-
CUDACHECK(cudaEventElapsedTime(&scatterMs, std::get<0>(events[i]), std::get<1>(events[i])));
152-
CUDACHECK(cudaEventElapsedTime(&gatherMs, std::get<1>(events[i]), std::get<2>(events[i])));
153-
totalScatterMs += scatterMs;
154-
totalGatherMs += gatherMs;
150+
float dispatchMs = 0.0f, combineMs = 0.0f;
151+
CUDACHECK(cudaEventElapsedTime(&dispatchMs, std::get<0>(events[i]), std::get<1>(events[i])));
152+
CUDACHECK(cudaEventElapsedTime(&combineMs, std::get<1>(events[i]), std::get<2>(events[i])));
153+
totalDispatchMs += dispatchMs;
154+
totalCombineMs += combineMs;
155155
}
156-
return {totalScatterMs / numSamples, totalGatherMs / numSamples};
156+
return {totalDispatchMs / numSamples, totalCombineMs / numSamples};
157157
};
158158

159159
MPI_Barrier(MPI_COMM_WORLD);
@@ -165,15 +165,15 @@ std::pair<Time, Time> benchmark(
165165

166166
MPI_Barrier(MPI_COMM_WORLD);
167167
nvtxRangePush("benchmark");
168-
std::vector<float> scatterTimeUs, gatherTimeUs;
168+
std::vector<float> dispatchTimeUs, combineTimeUs;
169169
for (int i = 0; i < repeat; i++) {
170-
auto [scatterTimeMs, gatherTimeMs] = run();
171-
scatterTimeUs.push_back(scatterTimeMs * 1000);
172-
gatherTimeUs.push_back(gatherTimeMs * 1000);
170+
auto [dispatchTimeMs, combineTimeMs] = run();
171+
dispatchTimeUs.push_back(dispatchTimeMs * 1000);
172+
combineTimeUs.push_back(combineTimeMs * 1000);
173173
}
174174
nvtxRangePop();
175175

176-
return {average(scatterTimeUs), average(gatherTimeUs)};
176+
return {average(dispatchTimeUs), average(combineTimeUs)};
177177
}
178178

179179
} // namespace
@@ -240,15 +240,16 @@ int main(int argc, char **argv) {
240240
};
241241

242242
for (const auto &config : configs) {
243-
auto [scatter, gather] = benchmark<nv_bfloat16>(10, config, currentPE, numPEs, stream);
243+
auto [dispatch, combine] = benchmark<nv_bfloat16>(10, config, currentPE, numPEs, stream);
244244
if (currentPE == 0) {
245-
auto [scatterMean, scatterStddev] = scatter;
246-
auto [gatherMean, gatherStddev] = gather;
245+
auto [dispatchMean, dispatchStddev] = dispatch;
246+
auto [combineMean, combineStddev] = combine;
247247
std::cout << std::setw(3) << config.numTokens << " " << std::setw(3) << config.numExperts
248248
<< " " << std::setw(3) << config.expertsPerToken << " " << std::setw(4)
249249
<< config.hiddenDim << " " << std::fixed << std::setprecision(3)
250-
<< "Scatter: " << std::setw(10) << scatterMean << "us ± " << scatterStddev << "us "
251-
<< "Gather: " << std::setw(10) << gatherMean << "us ± " << gatherStddev << "us"
250+
<< "Dispatch: " << std::setw(10) << dispatchMean << "us ± " << dispatchStddev
251+
<< "us "
252+
<< "Combine: " << std::setw(10) << combineMean << "us ± " << combineStddev << "us"
252253
<< std::endl;
253254
}
254255
}

csrc/all_to_all/internode.cpp

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,42 +36,42 @@ AllToAllInterNode::AllToAllInterNode(
3636
ROSE_ASSERT(numTokensBuffer != nullptr, "failed to allocate numTokensBuffer");
3737
cudaMemset(numTokensBuffer, 0, sizeof(uint64_t) * numLocalExperts * numDPGroups);
3838

39-
numScatterRecvBuffer =
39+
numDispatchRecvBuffer =
4040
(uint64_t *)nvshmem_malloc(sizeof(uint64_t) * numLocalExperts * numDPGroups);
41-
ROSE_ASSERT(numScatterRecvBuffer != nullptr, "failed to allocate numScatterRecvBuffer");
42-
cudaMemset(numScatterRecvBuffer, 0, sizeof(uint64_t) * numLocalExperts * numDPGroups);
41+
ROSE_ASSERT(numDispatchRecvBuffer != nullptr, "failed to allocate numDispatchRecvBuffer");
42+
cudaMemset(numDispatchRecvBuffer, 0, sizeof(uint64_t) * numLocalExperts * numDPGroups);
4343

44-
gatherSignalBuffer = (uint64_t *)nvshmem_malloc(sizeof(uint64_t) * maxNumTokens);
45-
ROSE_ASSERT(gatherSignalBuffer != nullptr, "failed to allocate gatherSignalBuffer");
46-
cudaMemset(gatherSignalBuffer, 0, sizeof(uint64_t) * maxNumTokens);
44+
combineSignalBuffer = (uint64_t *)nvshmem_malloc(sizeof(uint64_t) * maxNumTokens);
45+
ROSE_ASSERT(combineSignalBuffer != nullptr, "failed to allocate combineSignalBuffer");
46+
cudaMemset(combineSignalBuffer, 0, sizeof(uint64_t) * maxNumTokens);
4747

48-
gatherSyncBuffer = (uint64_t *)nvshmem_malloc(sizeof(uint64_t) * worldSize);
49-
ROSE_ASSERT(gatherSyncBuffer != nullptr, "failed to allocate gatherSyncBuffer");
50-
cudaMemset(gatherSyncBuffer, 0, sizeof(uint64_t) * worldSize);
48+
combineSyncBuffer = (uint64_t *)nvshmem_malloc(sizeof(uint64_t) * worldSize);
49+
ROSE_ASSERT(combineSyncBuffer != nullptr, "failed to allocate combineSyncBuffer");
50+
cudaMemset(combineSyncBuffer, 0, sizeof(uint64_t) * worldSize);
5151

52-
// Buffers for scatter.
52+
// Buffers for dispatch.
5353
const size_t perTokenBytes =
5454
round_up<size_t>(hiddenDimBytes + hiddenDimScaleBytes + sizeof(uint32_t), 16);
55-
xScatterIn = (std::byte *)nvshmem_malloc(maxNumTokens * perTokenBytes);
56-
ROSE_ASSERT(xScatterIn != nullptr, "failed to allocate xScatterIn");
57-
xScatterOut = (std::byte *)nvshmem_malloc(maxBatchTokens * perTokenBytes);
58-
ROSE_ASSERT(xScatterOut != nullptr, "failed to allocate xScatterOut");
55+
xDispatchIn = (std::byte *)nvshmem_malloc(maxNumTokens * perTokenBytes);
56+
ROSE_ASSERT(xDispatchIn != nullptr, "failed to allocate xDispatchIn");
57+
xDispatchOut = (std::byte *)nvshmem_malloc(maxBatchTokens * perTokenBytes);
58+
ROSE_ASSERT(xDispatchOut != nullptr, "failed to allocate xDispatchOut");
5959

60-
// Buffers for gather. The allocations are a bit wider to accommodate all
60+
// Buffers for combine. The allocations are a bit wider to accommodate all
6161
// possible data types (primarily float for testing and bfloat16 for prod).
62-
xGatherIn = (std::byte *)nvshmem_malloc(maxBatchTokens * hiddenDim * sizeof(float));
63-
ROSE_ASSERT(xGatherIn != nullptr, "failed to allocate xGatherIn");
64-
xGatherOut = (std::byte *)nvshmem_malloc(maxNumTokens * numExperts * hiddenDim * sizeof(float));
65-
ROSE_ASSERT(xGatherOut != nullptr, "failed to allocate xGatherOut");
62+
xCombineIn = (std::byte *)nvshmem_malloc(maxBatchTokens * hiddenDim * sizeof(float));
63+
ROSE_ASSERT(xCombineIn != nullptr, "failed to allocate xCombineIn");
64+
xCombineOut = (std::byte *)nvshmem_malloc(maxNumTokens * numExperts * hiddenDim * sizeof(float));
65+
ROSE_ASSERT(xCombineOut != nullptr, "failed to allocate xCombineOut");
6666
}
6767

6868
AllToAllInterNode::~AllToAllInterNode() {
6969
nvshmem_free(numTokensBuffer);
70-
nvshmem_free(numScatterRecvBuffer);
71-
nvshmem_free(gatherSignalBuffer);
72-
nvshmem_free(gatherSyncBuffer);
73-
nvshmem_free(xScatterIn);
74-
nvshmem_free(xScatterOut);
75-
nvshmem_free(xGatherIn);
76-
nvshmem_free(xGatherOut);
70+
nvshmem_free(numDispatchRecvBuffer);
71+
nvshmem_free(combineSignalBuffer);
72+
nvshmem_free(combineSyncBuffer);
73+
nvshmem_free(xDispatchIn);
74+
nvshmem_free(xDispatchOut);
75+
nvshmem_free(xCombineIn);
76+
nvshmem_free(xCombineOut);
7777
}

csrc/all_to_all/internode.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class AllToAllInterNode final : public AllToAll {
6060
/// overlapping).
6161
///
6262
/// @param stream The CUDA stream to launch the kernel on.
63-
void scatter(
63+
void dispatch(
6464
const Strided1D<int32_t> &outTokensPerExpert,
6565
const Strided2D<std::byte> &expertX,
6666
const Strided2D<std::byte> &expertXScale,
@@ -73,7 +73,7 @@ class AllToAllInterNode final : public AllToAll {
7373
cudaStream_t stream
7474
);
7575

76-
/// Launches the all-to-all gather kernel.
76+
/// Launches the all-to-all combine kernel.
7777
///
7878
/// @param outTokens The output tokens.
7979
/// Shape: [numExperts, maxNumTokens].
@@ -97,7 +97,7 @@ class AllToAllInterNode final : public AllToAll {
9797
///
9898
/// @param stream The CUDA stream to launch the kernel on.
9999
template <typename T>
100-
void gather(
100+
void combine(
101101
const Strided1D<nv_bfloat16> &outTokens,
102102
const Strided2D<uint32_t> &indices,
103103
const Strided2D<float> &weights,
@@ -111,13 +111,13 @@ class AllToAllInterNode final : public AllToAll {
111111
private:
112112
/// @section Pre-allocated symmetric shared memory workspace.
113113
uint64_t *numTokensBuffer = nullptr;
114-
uint64_t *numScatterRecvBuffer = nullptr;
115-
uint64_t *gatherSignalBuffer = nullptr;
116-
uint64_t *gatherSyncBuffer = nullptr;
117-
std::byte *xScatterIn = nullptr;
118-
std::byte *xScatterOut = nullptr;
119-
std::byte *xGatherIn = nullptr;
120-
std::byte *xGatherOut = nullptr;
114+
uint64_t *numDispatchRecvBuffer = nullptr;
115+
uint64_t *combineSignalBuffer = nullptr;
116+
uint64_t *combineSyncBuffer = nullptr;
117+
std::byte *xDispatchIn = nullptr;
118+
std::byte *xDispatchOut = nullptr;
119+
std::byte *xCombineIn = nullptr;
120+
std::byte *xCombineOut = nullptr;
121121
};
122122

123123
} // namespace pplx

csrc/all_to_all/internode_gather.cu renamed to csrc/all_to_all/internode_combine.cu

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
using namespace pplx;
1010

1111
template <typename T, size_t NUM_WARPS, bool DO_SEND, bool DO_RECV>
12-
__global__ __launch_bounds__(NUM_WARPS * 32, 1) void gatherKernel(
12+
__global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel(
1313
nv_bfloat16 *outTokens,
1414
size_t outTokensStrideElem,
1515
uint32_t *indices,
@@ -34,8 +34,8 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void gatherKernel(
3434
const uint32_t *sourceIndex,
3535
const uint32_t *sourceOffset,
3636
const uint32_t *sourceGroup,
37-
uint64_t *gatherSignalBuffer,
38-
uint64_t *gatherSyncBuffer,
37+
uint64_t *combineSignalBuffer,
38+
uint64_t *combineSyncBuffer,
3939
std::byte *xBufferIn,
4040
std::byte *xBufferOut
4141
) {
@@ -50,7 +50,7 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void gatherKernel(
5050
if (DO_SEND) {
5151
for (unsigned i = blockIdx.x * numWarps + warpId; i < worldSize; i += gridDim.x * numWarps) {
5252
if (laneId == 0) {
53-
nvshmemx_signal_op(&gatherSyncBuffer[rank], 1, NVSHMEM_SIGNAL_SET, i);
53+
nvshmemx_signal_op(&combineSyncBuffer[rank], 1, NVSHMEM_SIGNAL_SET, i);
5454
}
5555
}
5656

@@ -84,7 +84,7 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void gatherKernel(
8484
const unsigned index = dstExpert * maxNumTokens + source;
8585
std::byte *dstPtr = xBufferOut + index * stride;
8686
nvshmemx_putmem_signal_nbi_warp(
87-
dstPtr, xTokenPtr, stride, &gatherSignalBuffer[source], 1, NVSHMEM_SIGNAL_ADD, dstRank
87+
dstPtr, xTokenPtr, stride, &combineSignalBuffer[source], 1, NVSHMEM_SIGNAL_ADD, dstRank
8888
);
8989
}
9090
}
@@ -100,9 +100,9 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void gatherKernel(
100100
// Compute the weighted sum of the input tokens.
101101
const size_t localNumTokens = boundM ? __ldg(boundM) : m;
102102
for (unsigned i = blockIdx.x; i < localNumTokens; i += gridDim.x) {
103-
nvshmem_uint64_wait_until(&gatherSignalBuffer[i], NVSHMEM_CMP_EQ, expertsPerToken);
103+
nvshmem_uint64_wait_until(&combineSignalBuffer[i], NVSHMEM_CMP_EQ, expertsPerToken);
104104
__syncthreads();
105-
gatherSignalBuffer[i] = 0;
105+
combineSignalBuffer[i] = 0;
106106

107107
nv_bfloat16 *dstPtr = outTokens + i * outTokensStrideElem;
108108
constexpr unsigned VEC_SIZE = 8;
@@ -134,14 +134,14 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void gatherKernel(
134134

135135
for (unsigned i = blockIdx.x * blockDim.x + threadIdx.x; i < worldSize;
136136
i += gridDim.x * blockDim.x) {
137-
nvshmem_uint64_wait_until(&gatherSyncBuffer[i], NVSHMEM_CMP_EQ, 1);
138-
gatherSyncBuffer[i] = 0;
137+
nvshmem_uint64_wait_until(&combineSyncBuffer[i], NVSHMEM_CMP_EQ, 1);
138+
combineSyncBuffer[i] = 0;
139139
}
140140
}
141141
}
142142

143143
template <typename T>
144-
void AllToAllInterNode::gather(
144+
void AllToAllInterNode::combine(
145145
const Strided1D<nv_bfloat16> &outTokens,
146146
const Strided2D<uint32_t> &indices,
147147
const Strided2D<float> &weights,
@@ -189,26 +189,26 @@ void AllToAllInterNode::gather(
189189
&sourceIndex,
190190
&sourceOffset,
191191
&sourceGroup,
192-
&gatherSignalBuffer,
193-
&gatherSyncBuffer,
194-
&xGatherIn,
195-
&xGatherOut};
192+
&combineSignalBuffer,
193+
&combineSyncBuffer,
194+
&xCombineIn,
195+
&xCombineOut};
196196

197-
nvtxRangePush("gather");
197+
nvtxRangePush("combine");
198198
switch (splitMode) {
199199
case SplitMode::SEND:
200200
CUDACHECK(cudaLaunchCooperativeKernel(
201-
&gatherKernel<T, NUM_WARPS, true, false>, dimGrid, dimBlock, args, 0, stream
201+
&combineKernel<T, NUM_WARPS, true, false>, dimGrid, dimBlock, args, 0, stream
202202
));
203203
break;
204204
case SplitMode::RECV:
205205
CUDACHECK(cudaLaunchCooperativeKernel(
206-
&gatherKernel<T, NUM_WARPS, false, true>, dimGrid, dimBlock, args, 0, stream
206+
&combineKernel<T, NUM_WARPS, false, true>, dimGrid, dimBlock, args, 0, stream
207207
));
208208
break;
209209
case SplitMode::NONE:
210210
CUDACHECK(cudaLaunchCooperativeKernel(
211-
&gatherKernel<T, NUM_WARPS, true, true>, dimGrid, dimBlock, args, 0, stream
211+
&combineKernel<T, NUM_WARPS, true, true>, dimGrid, dimBlock, args, 0, stream
212212
));
213213
break;
214214
default:
@@ -217,8 +217,8 @@ void AllToAllInterNode::gather(
217217
nvtxRangePop();
218218
}
219219

220-
#define INSTANTIATE_GATHER(T) \
221-
template void AllToAllInterNode::gather<T>( \
220+
#define INSTANTIATE_COMBINE(T) \
221+
template void AllToAllInterNode::combine<T>( \
222222
const Strided1D<nv_bfloat16> &outTokens, \
223223
const Strided2D<uint32_t> &indices, \
224224
const Strided2D<float> &weights, \
@@ -229,6 +229,6 @@ void AllToAllInterNode::gather(
229229
cudaStream_t stream \
230230
);
231231

232-
INSTANTIATE_GATHER(float)
233-
INSTANTIATE_GATHER(half)
234-
INSTANTIATE_GATHER(nv_bfloat16)
232+
INSTANTIATE_COMBINE(float)
233+
INSTANTIATE_COMBINE(half)
234+
INSTANTIATE_COMBINE(nv_bfloat16)

0 commit comments

Comments
 (0)