@@ -48,7 +48,7 @@ HostIrEvaluator::HostIrEvaluator(
4848 expr_evaluator_(),
4949 my_local_device_index_(
5050 communicator_ == nullptr ? 0 : communicator_->local_rank ()),
51- ipc_handle_cache_(expr_evaluator_),
51+ ipc_handle_cache_(& expr_evaluator_),
5252 multicast_handle_cache_() {
5353 const DeviceIdxType device_index =
5454 communicator_ == nullptr ? 0 : communicator_->deviceId ();
@@ -364,16 +364,27 @@ void HostIrEvaluator::handle(Communication* communication) {
364364 .stream ());
365365 NVF_ERROR (
366366 communication->type () == CommunicationType::Broadcast ||
367- communication->type () == CommunicationType::Allgather,
368- " Invalid communication type, expected Broadcast or Allgather, got: " ,
367+ communication->type () == CommunicationType::Allgather ||
368+ communication->type () == CommunicationType::Allreduce ||
369+ communication->type () == CommunicationType::Reduce,
370+ " Invalid communication type for CUDA backend, got: " ,
369371 communication->type ());
370372 int64_t root_val =
371373 expr_evaluator_.evaluate (communication->root ()).as <int64_t >();
374+ int64_t cache_root =
375+ (communication->type () == CommunicationType::Broadcast ||
376+ communication->type () == CommunicationType::Reduce)
377+ ? root_val
378+ : -1 ;
379+ // For Reduce, non-roots may have no output; use input for cache key
380+ at::Tensor cache_buffer =
381+ output_tensor.defined () ? output_tensor : input_tensor;
372382 SymmetricMemoryHandle* multicast_handle = multicast_handle_cache_.get (
373- {.buffer = output_tensor , .expr = communication, .root = root_val });
383+ {.buffer = cache_buffer , .expr = communication, .root = cache_root });
374384 postWithCudaBackend (
375385 communication,
376386 input_tensor,
387+ output_tensor,
377388 multicast_handle,
378389 current_stream,
379390 root_val);
@@ -492,15 +503,25 @@ void HostIrEvaluator::handle(Wait* wait) {
492503 communication && communication->backend () == CommunicatorBackend::kCuda ) {
493504 NVF_ERROR (
494505 communication->type () == CommunicationType::Broadcast ||
495- communication->type () == CommunicationType::Allgather,
496- " Invalid communication type, only Broadcast and Allgather are "
497- " supported with cuda backend, got: " ,
506+ communication->type () == CommunicationType::Allgather ||
507+ communication->type () == CommunicationType::Allreduce ||
508+ communication->type () == CommunicationType::Reduce,
509+ " Invalid communication type for CUDA backend, got: " ,
498510 communication->type ());
499511 at::Tensor output_tensor = getKnownTensorOrUndefined (communication->out ());
512+ at::Tensor input_tensor =
513+ getKnownTensorOrUndefined (communication->input (0 ));
500514 int64_t root_val =
501515 expr_evaluator_.evaluate (communication->root ()).as <int64_t >();
516+ int64_t cache_root =
517+ (communication->type () == CommunicationType::Broadcast ||
518+ communication->type () == CommunicationType::Reduce)
519+ ? root_val
520+ : -1 ;
521+ at::Tensor cache_buffer =
522+ output_tensor.defined () ? output_tensor : input_tensor;
502523 SymmetricMemoryHandle* multicast_handle = multicast_handle_cache_.get (
503- {.buffer = output_tensor , .expr = communication, .root = root_val });
524+ {.buffer = cache_buffer , .expr = communication, .root = cache_root });
504525 waitWithCudaBackend (
505526 communication, multicast_handle, current_stream, root_val);
506527 } else {
0 commit comments