diff --git a/csrc/host_ir/evaluator.cpp b/csrc/host_ir/evaluator.cpp index c464f904560..5da37322588 100644 --- a/csrc/host_ir/evaluator.cpp +++ b/csrc/host_ir/evaluator.cpp @@ -376,9 +376,15 @@ void HostIrEvaluator::handle(Communication* communication) { communication->type() == CommunicationType::Reduce) ? root_val : -1; - // For Reduce, non-roots may have no output; use input for cache key + // For Broadcast and Reduce, non-roots may have no output; use input for + // cache key at::Tensor cache_buffer = output_tensor.defined() ? output_tensor : input_tensor; + // For Reduce specifically, use the input tensor for the cache key + // since the output doesn't need to be symmetric + if (communication->type() == CommunicationType::Reduce) { + cache_buffer = input_tensor; + } SymmetricMemoryHandle* multicast_handle = multicast_handle_cache_.get( {.buffer = cache_buffer, .expr = communication, .root = cache_root}); postWithCudaBackend( diff --git a/csrc/multidevice/cuda_p2p.cpp b/csrc/multidevice/cuda_p2p.cpp index a31cee0c1b8..d96b033744d 100644 --- a/csrc/multidevice/cuda_p2p.cpp +++ b/csrc/multidevice/cuda_p2p.cpp @@ -857,7 +857,6 @@ void waitAllreduceWithCudaBackend( void postReduceWithCudaBackend( Communication* communication, - at::Tensor input, at::Tensor output, SymmetricMemoryForReduce* reduce_handle, CUstream stream, @@ -875,7 +874,7 @@ void postReduceWithCudaBackend( "Only SUM reduction is supported for multimem reduce; got ", communication->reduceOp()); NVF_CHECK( - input.scalar_type() == at::kFloat && + reduce_handle->inputBuffer().scalar_type() == at::kFloat && (!output.defined() || output.scalar_type() == at::kFloat), "Only float32 is supported for multimem reduce."); @@ -886,14 +885,6 @@ void postReduceWithCudaBackend( "size=", size); - // Copy input to symmetric buffer - NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync( - reduce_handle->inputBuffer().data_ptr(), - input.data_ptr(), - size, - cudaMemcpyDeviceToDevice, - stream)); - const int64_t world_size = 
communicator.size(); // All ranks signal ready by writing kInProgress to their own semaphore NVFUSER_CUDA_SAFE_CALL(cuStreamWriteValue32( @@ -923,10 +914,10 @@ void postReduceWithCudaBackend( cuStreamBatchMemOp(stream, world_size - 1, ops.data(), 0)); // Root launches the ld_reduce kernel - void* dst = output.defined() ? output.data_ptr() - : reduce_handle->inputBuffer().data_ptr(); + NVF_ERROR( + output.defined(), "Output must be defined for reduce on the root"); launchMulticastReduceKernel( - reduce_handle->multicastPtr(), dst, size, stream); + reduce_handle->multicastPtr(), output.data_ptr(), size, stream); // Root signals completion by writing kIdle to all non-root semaphores std::vector<CUstreamBatchMemOpParams> write_complete_ops(world_size - 1); @@ -1267,7 +1258,7 @@ void postWithCudaBackend( dynamic_cast<SymmetricMemoryForReduce*>(symmetric_memory_handle); NVF_ERROR(reduce_handle != nullptr, "Invalid reduce handle"); postReduceWithCudaBackend( - communication, input, output, reduce_handle, stream, root); + communication, output, reduce_handle, stream, root); break; } default: diff --git a/csrc/multidevice/ipc_handle.cpp b/csrc/multidevice/ipc_handle.cpp index e6855b15d1d..a489210e8f6 100644 --- a/csrc/multidevice/ipc_handle.cpp +++ b/csrc/multidevice/ipc_handle.cpp @@ -280,17 +280,14 @@ void* SymmetricMemoryForAllreduce::semaphoreUnicastPtr( SymmetricMemoryForReduce::SymmetricMemoryForReduce( Communication* communication, int64_t root, - at::Tensor output_buffer) - : size_bytes_(output_buffer.numel() * output_buffer.element_size()) { + at::Tensor buffer) + : size_bytes_(buffer.numel() * buffer.element_size()) { std::string name_suffix = "for_Communication" + std::to_string(communication->name()); std::string store_key_prefix = "nvls_reduce_" + name_suffix; - at::Tensor input_sym = SymmetricTensor::allocate( - output_buffer.sizes(), - output_buffer.scalar_type(), - output_buffer.device()); - input_sym_tensor_ = std::make_unique<SymmetricTensor>(input_sym); + // We assume the input buffer is already a symmetric tensor + 
input_sym_tensor_ = std::make_unique<SymmetricTensor>(buffer); input_sym_tensor_->setupRemoteHandles(store_key_prefix + "_input_unicast"); MulticastProtocol protocol = getMulticastProtocol(); @@ -303,7 +300,7 @@ SymmetricMemoryForReduce::SymmetricMemoryForReduce( at::Tensor semaphore = SymmetricTensor::allocate( /*sizes=*/at::IntArrayRef({1}), /*dtype=*/at::ScalarType::Int, - /*device=*/output_buffer.device()); + /*device=*/buffer.device()); // Initialize the semaphore to kIdle IpcSemaphore init_value = IpcSemaphore::kIdle; diff --git a/tests/cpp/test_multidevice_stream_parallel_type.cpp b/tests/cpp/test_multidevice_stream_parallel_type.cpp index 05870e31f03..fd713d48621 100644 --- a/tests/cpp/test_multidevice_stream_parallel_type.cpp +++ b/tests/cpp/test_multidevice_stream_parallel_type.cpp @@ -764,6 +764,12 @@ TEST_P(RSMatmulTest, ReduceScatterReduceBased) { C->axis(1)->parallelize(ParallelType::DIDx); C->axis(0)->parallelize(ParallelType::Stream); + // Set the matmul's output to be symmetric in order to use NVLS multimem + // reduce. + if (communicator_backend == CommunicatorBackend::kCuda) { + C_unreduced->setMemoryType(MemoryType::Symmetric); + } + MultiDeviceExecutorParams params; params.lower.communicator_backend = communicator_backend; params.lower.offset_stream_indexing_by_rank = false; // Will fail if true