Skip to content

Commit 5c17125

Browse files
nsarka, greptile-apps[bot], and wujingyue
authored
Reduce and Allreduce NVLS implementations for the cuda backend (#6038)
Built on top of #5620. Adds reduce and allreduce NVLS implementations. Both use the same ld_reduce kernel and synchronize using a symmetric integer tensor as a semaphore.

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
1 parent 617fa07 commit 5c17125

13 files changed

+1316
-93
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1243,6 +1243,7 @@ list(APPEND NVFUSER_RUNTIME_FILES
12431243
${NVFUSER_ROOT}/runtime/mbarrier.cu
12441244
${NVFUSER_ROOT}/runtime/memory.cu
12451245
${NVFUSER_ROOT}/runtime/multicast.cu
1246+
${NVFUSER_ROOT}/runtime/multicast_reduce.cu
12461247
${NVFUSER_ROOT}/runtime/alltoallv.cu
12471248
${NVFUSER_ROOT}/runtime/tma_copy.cu
12481249
${NVFUSER_ROOT}/runtime/random_numbers.cu

csrc/host_ir/evaluator.cpp

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ HostIrEvaluator::HostIrEvaluator(
4848
expr_evaluator_(),
4949
my_local_device_index_(
5050
communicator_ == nullptr ? 0 : communicator_->local_rank()),
51-
ipc_handle_cache_(expr_evaluator_),
51+
ipc_handle_cache_(&expr_evaluator_),
5252
multicast_handle_cache_() {
5353
const DeviceIdxType device_index =
5454
communicator_ == nullptr ? 0 : communicator_->deviceId();
@@ -364,16 +364,27 @@ void HostIrEvaluator::handle(Communication* communication) {
364364
.stream());
365365
NVF_ERROR(
366366
communication->type() == CommunicationType::Broadcast ||
367-
communication->type() == CommunicationType::Allgather,
368-
"Invalid communication type, expected Broadcast or Allgather, got: ",
367+
communication->type() == CommunicationType::Allgather ||
368+
communication->type() == CommunicationType::Allreduce ||
369+
communication->type() == CommunicationType::Reduce,
370+
"Invalid communication type for CUDA backend, got: ",
369371
communication->type());
370372
int64_t root_val =
371373
expr_evaluator_.evaluate(communication->root()).as<int64_t>();
374+
int64_t cache_root =
375+
(communication->type() == CommunicationType::Broadcast ||
376+
communication->type() == CommunicationType::Reduce)
377+
? root_val
378+
: -1;
379+
// For Reduce, non-roots may have no output; use input for cache key
380+
at::Tensor cache_buffer =
381+
output_tensor.defined() ? output_tensor : input_tensor;
372382
SymmetricMemoryHandle* multicast_handle = multicast_handle_cache_.get(
373-
{.buffer = output_tensor, .expr = communication, .root = root_val});
383+
{.buffer = cache_buffer, .expr = communication, .root = cache_root});
374384
postWithCudaBackend(
375385
communication,
376386
input_tensor,
387+
output_tensor,
377388
multicast_handle,
378389
current_stream,
379390
root_val);
@@ -492,15 +503,25 @@ void HostIrEvaluator::handle(Wait* wait) {
492503
communication && communication->backend() == CommunicatorBackend::kCuda) {
493504
NVF_ERROR(
494505
communication->type() == CommunicationType::Broadcast ||
495-
communication->type() == CommunicationType::Allgather,
496-
"Invalid communication type, only Broadcast and Allgather are "
497-
"supported with cuda backend, got: ",
506+
communication->type() == CommunicationType::Allgather ||
507+
communication->type() == CommunicationType::Allreduce ||
508+
communication->type() == CommunicationType::Reduce,
509+
"Invalid communication type for CUDA backend, got: ",
498510
communication->type());
499511
at::Tensor output_tensor = getKnownTensorOrUndefined(communication->out());
512+
at::Tensor input_tensor =
513+
getKnownTensorOrUndefined(communication->input(0));
500514
int64_t root_val =
501515
expr_evaluator_.evaluate(communication->root()).as<int64_t>();
516+
int64_t cache_root =
517+
(communication->type() == CommunicationType::Broadcast ||
518+
communication->type() == CommunicationType::Reduce)
519+
? root_val
520+
: -1;
521+
at::Tensor cache_buffer =
522+
output_tensor.defined() ? output_tensor : input_tensor;
502523
SymmetricMemoryHandle* multicast_handle = multicast_handle_cache_.get(
503-
{.buffer = output_tensor, .expr = communication, .root = root_val});
524+
{.buffer = cache_buffer, .expr = communication, .root = cache_root});
504525
waitWithCudaBackend(
505526
communication, multicast_handle, current_stream, root_val);
506527
} else {

0 commit comments

Comments (0)