Commit be55fd7

varun-sundar-rabindranath and Varun Sundar Rabindranath authored
Combine: Support Float16 datatype for OutTokens (#3)
Support Float16 datatype for OutTokens:
- Replace the hard-coded `nv_bfloat16` with a template type for outTokens in the combine kernel
- Update the Python and C++ `all_to_all` tests
- Update the Python and C++ benchmarks

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent f525ab5 commit be55fd7

File tree: 7 files changed (+134, -68 lines)

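The change described above boils down to making the output-token element type a second template parameter `U` instead of a hard-coded `nv_bfloat16`. The following standalone sketch is illustrative toy code only, not the repository's actual API (the `combineTokens` name and its arguments are hypothetical); it assumes a recent CUDA toolkit for host-side `half`/`nv_bfloat16` conversions.

// Illustrative sketch only (hypothetical toy code, not the repository's API):
// the combine output dtype becomes a template parameter U instead of a
// hard-coded nv_bfloat16. Assumes a recent CUDA toolkit for host-side
// half/nv_bfloat16 conversions. Build: nvcc -std=c++17 sketch.cu -o sketch
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cstddef>
#include <cstdio>
#include <vector>

// T = expert/input token dtype, U = output token dtype (nv_bfloat16 or half).
template <typename T, typename U>
void combineTokens(U *outTokens, const T *expertX, const float *weights, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    // Accumulate in float, then narrow to the requested output dtype U.
    float sum = static_cast<float>(expertX[i]) * weights[i];
    outTokens[i] = static_cast<U>(sum);
  }
}

int main() {
  std::vector<nv_bfloat16> expertX(4, nv_bfloat16(1.5f));
  std::vector<float> weights(4, 0.5f);
  std::vector<half> outFp16(4);         // Float16 output path added by this commit
  std::vector<nv_bfloat16> outBf16(4);  // BFloat16 output path, as before
  combineTokens(outFp16.data(), expertX.data(), weights.data(), 4);
  combineTokens(outBf16.data(), expertX.data(), weights.data(), 4);
  std::printf("fp16 out[0] = %f\n", static_cast<float>(outFp16[0]));
  return 0;
}

The real kernel accumulates in float and writes back vectorized; the INSTANTIATE_COMBINE table in internode_combine.cu below covers the (T, U) pairs over {float, half, nv_bfloat16} x {nv_bfloat16, half}.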

csrc/all_to_all/bench_all_to_all.cpp

Lines changed: 26 additions & 12 deletions
@@ -40,7 +40,7 @@ Time average(const std::vector<float> &timesUs) {
   return std::make_pair(mean, stddev);
 }
 
-template <typename T>
+template <typename T, typename U>
 std::pair<Time, Time> benchmark(
     unsigned repeat,
     const BenchConfig &config,
@@ -65,7 +65,7 @@ std::pair<Time, Time> benchmark(
   DeviceBuffer<float> outExpertScaleDevice(
       expertsPerRank * config.numTokens * numPEs * data.hiddenDimScale
   );
-  DeviceBuffer<nv_bfloat16> outTokensDevice(config.numTokens * data.hiddenDim);
+  DeviceBuffer<U> outTokensDevice(config.numTokens * data.hiddenDim);
   DeviceBuffer<T> xDevice(data.x);
   DeviceBuffer<float> xScaleDevice(data.xScale);
   DeviceBuffer<uint32_t> indicesDevice(data.indices);
@@ -128,8 +128,8 @@ std::pair<Time, Time> benchmark(
 
     CUDACHECK(cudaEventRecord(std::get<1>(events[i]), stream));
 
-    allToAll.combine<T>(
-        Strided1D<nv_bfloat16>(outTokensDevice, config.hiddenDim),
+    allToAll.combine<T, U>(
+        Strided1D<U>(outTokensDevice, config.hiddenDim),
         Strided2D<uint32_t>(indicesDevice, 1, config.expertsPerToken),
         Strided2D<float>(weightsDevice, 1, config.expertsPerToken),
         Strided2D<T>(
@@ -239,19 +239,33 @@ int main(int argc, char **argv) {
       {128, 256, 8, 7168, 128},
   };
 
-  for (const auto &config : configs) {
-    auto [dispatch, combine] = benchmark<nv_bfloat16>(10, config, currentPE, numPEs, stream);
-    if (currentPE == 0) {
-      auto [dispatchMean, dispatchStddev] = dispatch;
-      auto [combineMean, combineStddev] = combine;
-      std::cout << std::setw(3) << config.numTokens << " " << std::setw(3) << config.numExperts
-                << " " << std::setw(3) << config.expertsPerToken << " " << std::setw(4)
-                << config.hiddenDim << " " << std::fixed << std::setprecision(3)
+  auto maybe_print_bench_results = [](int const myPE,
+                                      BenchConfig const &config,
+                                      Time const &dispatch_time,
+                                      Time const &combine_time,
+                                      std::string const description = "") {
+    if (myPE == 0) {
+      auto [dispatchMean, dispatchStddev] = dispatch_time;
+      auto [combineMean, combineStddev] = combine_time;
+      std::cout << description << std::setw(6) << config.numTokens << " " << std::setw(3)
+                << config.numExperts << " " << std::setw(3) << config.expertsPerToken << " "
+                << std::setw(4) << config.hiddenDim << " " << std::fixed << std::setprecision(3)
                 << "Dispatch: " << std::setw(10) << dispatchMean << "us ± " << dispatchStddev
                 << "us "
                 << "Combine: " << std::setw(10) << combineMean << "us ± " << combineStddev << "us"
                 << std::endl;
     }
+  };
+
+  for (const auto &config : configs) {
+    auto [dispatch, combine] =
+        benchmark<nv_bfloat16, nv_bfloat16>(10, config, currentPE, numPEs, stream);
+    maybe_print_bench_results(currentPE, config, dispatch, combine, "nv_bfloat16->nv_bfloat16:");
+  }
+
+  for (const auto &config : configs) {
+    auto [dispatch, combine] = benchmark<half, half>(10, config, currentPE, numPEs, stream);
+    maybe_print_bench_results(currentPE, config, dispatch, combine, "half->half:");
   }
 
   // Cleanup.

csrc/all_to_all/internode.h

Lines changed: 2 additions & 2 deletions
@@ -96,9 +96,9 @@ class AllToAllInterNode final : public AllToAll {
   /// Shape: [1].
   ///
   /// @param stream The CUDA stream to launch the kernel on.
-  template <typename T>
+  template <typename T, typename U>
   void combine(
-      const Strided1D<nv_bfloat16> &outTokens,
+      const Strided1D<U> &outTokens,
       const Strided2D<uint32_t> &indices,
       const Strided2D<float> &weights,
       const Strided2D<T> &expertX,

csrc/all_to_all/internode_combine.cu

Lines changed: 18 additions & 15 deletions
@@ -8,9 +8,9 @@
 
 using namespace pplx;
 
-template <typename T, size_t NUM_WARPS, bool DO_SEND, bool DO_RECV>
+template <typename T, typename U, size_t NUM_WARPS, bool DO_SEND, bool DO_RECV>
 __global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel(
-    nv_bfloat16 *outTokens,
+    U *outTokens,
     size_t outTokensStrideElem,
     uint32_t *indices,
     size_t indicesStrideElem,
@@ -104,7 +104,7 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel(
     __syncthreads();
     combineSignalBuffer[i] = 0;
 
-    nv_bfloat16 *dstPtr = outTokens + i * outTokensStrideElem;
+    U *dstPtr = outTokens + i * outTokensStrideElem;
     constexpr unsigned VEC_SIZE = 8;
     for (unsigned j = threadIdx.x * VEC_SIZE; j < hiddenDim; j += blockDim.x * VEC_SIZE) {
       float sum[VEC_SIZE];
@@ -140,9 +140,9 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void combineKernel(
   }
 }
 
-template <typename T>
+template <typename T, typename U>
 void AllToAllInterNode::combine(
-    const Strided1D<nv_bfloat16> &outTokens,
+    const Strided1D<U> &outTokens,
     const Strided2D<uint32_t> &indices,
     const Strided2D<float> &weights,
     const Strided2D<T> &expertX,
@@ -165,7 +165,7 @@ void AllToAllInterNode::combine(
   dim3 dimBlock(NUM_WARPS * 32, 1, 1);
 
   void *args[] = {
-      const_cast<nv_bfloat16 **>(&outTokens.data),
+      const_cast<U **>(&outTokens.data),
       const_cast<size_t *>(&outTokens.strideElem),
       const_cast<uint32_t **>(&indices.data),
      const_cast<size_t *>(&indices.strideElem),
@@ -198,17 +198,17 @@ void AllToAllInterNode::combine(
   switch (splitMode) {
   case SplitMode::SEND:
     CUDACHECK(cudaLaunchCooperativeKernel(
-        (void *)&combineKernel<T, NUM_WARPS, true, false>, dimGrid, dimBlock, args, 0, stream
+        (void *)&combineKernel<T, U, NUM_WARPS, true, false>, dimGrid, dimBlock, args, 0, stream
     ));
     break;
   case SplitMode::RECV:
     CUDACHECK(cudaLaunchCooperativeKernel(
-        (void *)&combineKernel<T, NUM_WARPS, false, true>, dimGrid, dimBlock, args, 0, stream
+        (void *)&combineKernel<T, U, NUM_WARPS, false, true>, dimGrid, dimBlock, args, 0, stream
     ));
     break;
   case SplitMode::NONE:
     CUDACHECK(cudaLaunchCooperativeKernel(
-        (void *)&combineKernel<T, NUM_WARPS, true, true>, dimGrid, dimBlock, args, 0, stream
+        (void *)&combineKernel<T, U, NUM_WARPS, true, true>, dimGrid, dimBlock, args, 0, stream
     ));
     break;
   default:
@@ -217,9 +217,9 @@ void AllToAllInterNode::combine(
   nvtxRangePop();
 }
 
-#define INSTANTIATE_COMBINE(T) \
-  template void AllToAllInterNode::combine<T>( \
-      const Strided1D<nv_bfloat16> &outTokens, \
+#define INSTANTIATE_COMBINE(T, U) \
+  template void AllToAllInterNode::combine<T, U>( \
+      const Strided1D<U> &outTokens, \
       const Strided2D<uint32_t> &indices, \
       const Strided2D<float> &weights, \
       const Strided2D<T> &expertX, \
@@ -229,6 +229,9 @@ void AllToAllInterNode::combine(
       cudaStream_t stream \
   );
 
-INSTANTIATE_COMBINE(float)
-INSTANTIATE_COMBINE(half)
-INSTANTIATE_COMBINE(nv_bfloat16)
+INSTANTIATE_COMBINE(float, nv_bfloat16)
+INSTANTIATE_COMBINE(half, nv_bfloat16)
+INSTANTIATE_COMBINE(nv_bfloat16, nv_bfloat16)
+INSTANTIATE_COMBINE(float, half)
+INSTANTIATE_COMBINE(half, half)
+INSTANTIATE_COMBINE(nv_bfloat16, half)

csrc/all_to_all/test_all_to_all.cpp

Lines changed: 18 additions & 6 deletions
@@ -19,7 +19,7 @@
 
 using namespace pplx;
 
-template <typename T, typename Kernel>
+template <typename T, typename U, typename Kernel>
 bool testDispatchCombine(
     cudaStream_t stream,
     unsigned dpRank,
@@ -75,7 +75,7 @@ bool testDispatchCombine(
   DeviceBuffer<float> outExpertScaleDevice(
       expertsPerRank * maxNumTokens * numDPGroups * rank.hiddenDimScale
   );
-  DeviceBuffer<nv_bfloat16> outTokensDevice(maxNumTokens * hiddenDim);
+  DeviceBuffer<U> outTokensDevice(maxNumTokens * hiddenDim);
 
   const size_t hiddenDimBytes = rank.hiddenDim * sizeof(T);
   const size_t hiddenDimScaleBytes = rank.hiddenDimScale * sizeof(float);
@@ -113,7 +113,7 @@
   CUDACHECK(cudaStreamSynchronize(stream));
 
   allToAll.combine(
-      Strided1D<nv_bfloat16>(outTokensDevice, hiddenDim),
+      Strided1D<U>(outTokensDevice, hiddenDim),
       Strided2D<uint32_t>(indicesDevice, 1, expertsPerToken),
       Strided2D<float>(weightsDevice, 1, expertsPerToken),
       Strided2D<T>(outExpertDevice, hiddenDim, hiddenDim * maxNumTokens * numDPGroups),
@@ -127,7 +127,7 @@
   HostBuffer<int32_t> outNumTokensPerExpertHost(outTokensPerExpertDevice);
   HostBuffer<T> outExpertHost(outExpertDevice);
   HostBuffer<float> outExpertScaleHost(outExpertScaleDevice);
-  HostBuffer<nv_bfloat16> outTokensHost(outTokensDevice);
+  HostBuffer<U> outTokensHost(outTokensDevice);
 
   // Print the results.
   for (unsigned i = 0; i < epSize; ++i) {
@@ -322,10 +322,22 @@ int main(int argc, char **argv) {
 
   // Run the tests.
   int exit_code = EXIT_SUCCESS;
-  if (!testDispatchCombine<float, AllToAllInterNode>(stream, rank / 2, 2, rank, world_size)) {
+  if (!testDispatchCombine<float, nv_bfloat16, AllToAllInterNode>(
+          stream, rank / 2, 2, rank, world_size
+      )) {
    exit_code = EXIT_FAILURE;
   }
-  if (!testDispatchCombine<nv_bfloat16, AllToAllInterNode>(stream, rank / 2, 2, rank, world_size)) {
+  if (!testDispatchCombine<nv_bfloat16, nv_bfloat16, AllToAllInterNode>(
+          stream, rank / 2, 2, rank, world_size
+      )) {
+    exit_code = EXIT_FAILURE;
+  }
+  if (!testDispatchCombine<float, half, AllToAllInterNode>(stream, rank / 2, 2, rank, world_size)) {
+    exit_code = EXIT_FAILURE;
+  }
+  if (!testDispatchCombine<nv_bfloat16, half, AllToAllInterNode>(
+          stream, rank / 2, 2, rank, world_size
+      )) {
     exit_code = EXIT_FAILURE;
   }
 
csrc/bindings/all_to_all_ops.cpp

Lines changed: 23 additions & 7 deletions
@@ -136,7 +136,10 @@ void combine(
     bool doRecv
 ) {
   _CHECK_TENSOR(2, outTokens);
-  TORCH_CHECK(outTokens.scalar_type() == at::kBFloat16, "outTokens must be of type BFloat16");
+  TORCH_CHECK(
+      outTokens.scalar_type() == at::kBFloat16 || outTokens.scalar_type() == at::kHalf,
+      "outTokens must be of type BFloat16 or Float16"
+  );
   _CHECK_TENSOR(2, indices);
   TORCH_CHECK(indices.scalar_type() == at::kUInt32, "indices must be of type UInt32");
   _CHECK_TENSOR(2, weights);
@@ -149,9 +152,9 @@
   }
 
   auto *all_to_all = (AllToAllInterNode *)ptr;
-  auto run = [&]<typename T>() {
-    all_to_all->combine<T>(
-        Strided1D<nv_bfloat16>((nv_bfloat16 *)outTokens.data_ptr(), (size_t)outTokens.stride(0)),
+  auto run = [&]<typename T, typename U>() {
+    all_to_all->combine<T, U>(
+        Strided1D<U>((U *)outTokens.data_ptr(), (size_t)outTokens.stride(0)),
         Strided2D<uint32_t>(
            indices.data_ptr<uint32_t>(), (size_t)indices.stride(1), (size_t)indices.stride(0)
         ),
@@ -166,15 +169,28 @@
     );
   };
 
+  auto out_type_switch = [&]<typename T>(at::ScalarType const &out_dtype) {
+    switch (out_dtype) {
+    case at::kBFloat16:
+      run.operator()<T, nv_bfloat16>();
+      break;
+    case at::kHalf:
+      run.operator()<T, half>();
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported dtype for outTokens");
+    }
+  };
+
   switch (expertY.scalar_type()) {
   case at::kFloat:
-    run.operator()<float>();
+    out_type_switch.operator()<float>(outTokens.scalar_type());
     break;
   case at::kBFloat16:
-    run.operator()<nv_bfloat16>();
+    out_type_switch.operator()<nv_bfloat16>(outTokens.scalar_type());
     break;
   case at::kHalf:
-    run.operator()<half>();
+    out_type_switch.operator()<half>(outTokens.scalar_type());
     break;
   default:
     TORCH_CHECK(false, "Unsupported dtype for expertY");
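With this change, the binding picks the kernel instantiation from two runtime dtypes: the expert tensor's dtype selects `T` and the outTokens tensor's dtype selects `U`. The following is a hypothetical standalone sketch of that nested templated-lambda dispatch, not the repository's code; the `DType` enum and the `bf16`/`fp16` structs are stand-ins for `at::ScalarType`, `nv_bfloat16`, and `half`.

// Hypothetical standalone sketch (not the repository's code): nested
// runtime-dtype dispatch via C++20 templated lambdas, with a toy DType enum
// standing in for at::ScalarType.
// Build: g++ -std=c++20 dispatch_sketch.cpp -o dispatch_sketch
#include <cstdio>
#include <stdexcept>
#include <typeinfo>

enum class DType { Float32, BFloat16, Float16 };
struct bf16 {};  // stand-in for nv_bfloat16
struct fp16 {};  // stand-in for half

int main() {
  DType expertYType = DType::Float32;    // runtime dtype of the expert outputs -> picks T
  DType outTokensType = DType::Float16;  // runtime dtype of the combined tokens -> picks U

  // Innermost step: both element types are compile-time template arguments here.
  auto run = []<typename T, typename U>() {
    std::printf("combine<T=%s, U=%s>\n", typeid(T).name(), typeid(U).name());
  };

  // Inner switch: for a fixed input dtype T, pick the output dtype U.
  auto out_type_switch = [&]<typename T>(DType out_dtype) {
    switch (out_dtype) {
    case DType::BFloat16:
      run.operator()<T, bf16>();
      break;
    case DType::Float16:
      run.operator()<T, fp16>();
      break;
    default:
      throw std::runtime_error("Unsupported dtype for outTokens");
    }
  };

  // Outer switch: pick the input dtype T, then delegate the choice of U.
  switch (expertYType) {
  case DType::Float32:
    out_type_switch.operator()<float>(outTokensType);
    break;
  case DType::BFloat16:
    out_type_switch.operator()<bf16>(outTokensType);
    break;
  case DType::Float16:
    out_type_switch.operator()<fp16>(outTokensType);
    break;
  }
  return 0;
}

In the actual binding, the inner default branch raises via TORCH_CHECK, mirroring the pre-existing expertY dispatch.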

tests/bench_all_to_all.py

Lines changed: 27 additions & 17 deletions
@@ -234,30 +234,32 @@ def _worker_bench_all_to_all(
     pgi: ProcessGroupInfo,
     dp_size: int,
     in_dtype_str: str,
+    out_dtype_str: str,
 ) -> None:
     uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
     torch.distributed.broadcast(uid, src=0)
     nvshmem_init(uid, pgi.rank, pgi.world_size)
 
     in_dtype = getattr(torch, in_dtype_str)
+    out_dtype = getattr(torch, out_dtype_str)
     assert isinstance(in_dtype, torch.dtype)
     configs = [
         # V2-Lite: 64 Experts, 6 Experts per Token, 2048 Hidden Dim
-        MoEConfig(64, 6, 2048, 1, in_dtype),
-        MoEConfig(64, 6, 2048, 4, in_dtype),
-        MoEConfig(64, 6, 2048, 8, in_dtype),
-        MoEConfig(64, 6, 2048, 16, in_dtype),
-        MoEConfig(64, 6, 2048, 32, in_dtype),
-        MoEConfig(64, 6, 2048, 64, in_dtype),
-        MoEConfig(64, 6, 2048, 128, in_dtype),
+        MoEConfig(64, 6, 2048, 1, in_dtype, out_dtype),
+        MoEConfig(64, 6, 2048, 4, in_dtype, out_dtype),
+        MoEConfig(64, 6, 2048, 8, in_dtype, out_dtype),
+        MoEConfig(64, 6, 2048, 16, in_dtype, out_dtype),
+        MoEConfig(64, 6, 2048, 32, in_dtype, out_dtype),
+        MoEConfig(64, 6, 2048, 64, in_dtype, out_dtype),
+        MoEConfig(64, 6, 2048, 128, in_dtype, out_dtype),
         # R1 : 256 Experts, 8 Experts per Token, 7168 Hidden Dim
-        MoEConfig(256, 8, 7168, 1, in_dtype),
-        MoEConfig(256, 8, 7168, 4, in_dtype),
-        MoEConfig(256, 8, 7168, 8, in_dtype),
-        MoEConfig(256, 8, 7168, 16, in_dtype),
-        MoEConfig(256, 8, 7168, 32, in_dtype),
-        MoEConfig(256, 8, 7168, 64, in_dtype),
-        MoEConfig(256, 8, 7168, 128, in_dtype),
+        MoEConfig(256, 8, 7168, 1, in_dtype, out_dtype),
+        MoEConfig(256, 8, 7168, 4, in_dtype, out_dtype),
+        MoEConfig(256, 8, 7168, 8, in_dtype, out_dtype),
+        MoEConfig(256, 8, 7168, 16, in_dtype, out_dtype),
+        MoEConfig(256, 8, 7168, 32, in_dtype, out_dtype),
+        MoEConfig(256, 8, 7168, 64, in_dtype, out_dtype),
+        MoEConfig(256, 8, 7168, 128, in_dtype, out_dtype),
     ]
 
     header = [
@@ -340,18 +342,26 @@ def main() -> None:
     parser.add_argument("--dp-size", type=int, default=1)
     parser.add_argument(
         "--in-dtype",
-        choices=["bfloat16", "float8_e4m3fn"],
+        choices=["bfloat16", "float16", "float8_e4m3fn"],
         default="float8_e4m3fn",
     )
+    parser.add_argument(
+        "--out-dtype",
+        choices=["bfloat16", "float16"],
+        default="bfloat16",
+    )
     args = parser.parse_args()
     dp_size = int(args.dp_size)
     in_dtype = str(args.in_dtype)
+    out_dtype = str(args.out_dtype)
 
     if "MASTER_ADDR" in os.environ:
-        parallel_launch_from_env(_worker_bench_all_to_all, dp_size, in_dtype)
+        parallel_launch_from_env(_worker_bench_all_to_all, dp_size, in_dtype, out_dtype)
     else:
         world_size = torch.cuda.device_count()
-        parallel_launch(world_size, _worker_bench_all_to_all, dp_size, in_dtype)
+        parallel_launch(
+            world_size, _worker_bench_all_to_all, dp_size, in_dtype, out_dtype
+        )
 
 
 if __name__ == "__main__":
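With the new flag, the Python benchmark can presumably be invoked as, for example, `python tests/bench_all_to_all.py --in-dtype bfloat16 --out-dtype float16` (the exact launch command depends on the single-node vs. `MASTER_ADDR`-based setup handled above); `--out-dtype` defaults to `bfloat16`, preserving the previous behaviour.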
