Skip to content

Commit 0e9a309

Browse files
committed
Remove extra copy by assuming the CUDA NVLS reduce input tensor is symmetric
1 parent 5c17125 commit 0e9a309

File tree

4 files changed

+20
-23
lines changed

4 files changed

+20
-23
lines changed

csrc/host_ir/evaluator.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,9 +376,14 @@ void HostIrEvaluator::handle(Communication* communication) {
376376
communication->type() == CommunicationType::Reduce)
377377
? root_val
378378
: -1;
379-
// For Reduce, non-roots may have no output; use input for cache key
379+
// For Broadcast and Reduce, non-roots may have no output; use input for cache key
380380
at::Tensor cache_buffer =
381381
output_tensor.defined() ? output_tensor : input_tensor;
382+
// For Reduce specifically, use the input tensor for the cache key
383+
// since the output doesn't need to be symmetric
384+
if (communication->type() == CommunicationType::Reduce) {
385+
cache_buffer = input_tensor;
386+
}
382387
SymmetricMemoryHandle* multicast_handle = multicast_handle_cache_.get(
383388
{.buffer = cache_buffer, .expr = communication, .root = cache_root});
384389
postWithCudaBackend(

csrc/multidevice/cuda_p2p.cpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -857,7 +857,6 @@ void waitAllreduceWithCudaBackend(
857857

858858
void postReduceWithCudaBackend(
859859
Communication* communication,
860-
at::Tensor input,
861860
at::Tensor output,
862861
SymmetricMemoryForReduce* reduce_handle,
863862
CUstream stream,
@@ -875,7 +874,7 @@ void postReduceWithCudaBackend(
875874
"Only SUM reduction is supported for multimem reduce; got ",
876875
communication->reduceOp());
877876
NVF_CHECK(
878-
input.scalar_type() == at::kFloat &&
877+
reduce_handle->inputBuffer().scalar_type() == at::kFloat &&
879878
(!output.defined() || output.scalar_type() == at::kFloat),
880879
"Only float32 is supported for multimem reduce.");
881880

@@ -886,14 +885,6 @@ void postReduceWithCudaBackend(
886885
"size=",
887886
size);
888887

889-
// Copy input to symmetric buffer
890-
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
891-
reduce_handle->inputBuffer().data_ptr(),
892-
input.data_ptr(),
893-
size,
894-
cudaMemcpyDeviceToDevice,
895-
stream));
896-
897888
const int64_t world_size = communicator.size();
898889
// All ranks signal ready by writing kInProgress to their own semaphore
899890
NVFUSER_CUDA_SAFE_CALL(cuStreamWriteValue32(
@@ -923,10 +914,9 @@ void postReduceWithCudaBackend(
923914
cuStreamBatchMemOp(stream, world_size - 1, ops.data(), 0));
924915

925916
// Root launches the ld_reduce kernel
926-
void* dst = output.defined() ? output.data_ptr()
927-
: reduce_handle->inputBuffer().data_ptr();
917+
NVF_ERROR(output.defined(), "Output must be defined for reduce on the root");
928918
launchMulticastReduceKernel(
929-
reduce_handle->multicastPtr(), dst, size, stream);
919+
reduce_handle->multicastPtr(), output.data_ptr(), size, stream);
930920

931921
// Root signals completion by writing kIdle to all non-root semaphores
932922
std::vector<CUstreamBatchMemOpParams> write_complete_ops(world_size - 1);
@@ -1267,7 +1257,7 @@ void postWithCudaBackend(
12671257
dynamic_cast<SymmetricMemoryForReduce*>(symmetric_memory_handle);
12681258
NVF_ERROR(reduce_handle != nullptr, "Invalid reduce handle");
12691259
postReduceWithCudaBackend(
1270-
communication, input, output, reduce_handle, stream, root);
1260+
communication, output, reduce_handle, stream, root);
12711261
break;
12721262
}
12731263
default:

csrc/multidevice/ipc_handle.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -280,17 +280,14 @@ void* SymmetricMemoryForAllreduce::semaphoreUnicastPtr(
280280
SymmetricMemoryForReduce::SymmetricMemoryForReduce(
281281
Communication* communication,
282282
int64_t root,
283-
at::Tensor output_buffer)
284-
: size_bytes_(output_buffer.numel() * output_buffer.element_size()) {
283+
at::Tensor buffer)
284+
: size_bytes_(buffer.numel() * buffer.element_size()) {
285285
std::string name_suffix =
286286
"for_Communication" + std::to_string(communication->name());
287287
std::string store_key_prefix = "nvls_reduce_" + name_suffix;
288288

289-
at::Tensor input_sym = SymmetricTensor::allocate(
290-
output_buffer.sizes(),
291-
output_buffer.scalar_type(),
292-
output_buffer.device());
293-
input_sym_tensor_ = std::make_unique<SymmetricTensor>(input_sym);
289+
// We assume the input buffer is already a symmetric tensor
290+
input_sym_tensor_ = std::make_unique<SymmetricTensor>(buffer);
294291
input_sym_tensor_->setupRemoteHandles(store_key_prefix + "_input_unicast");
295292

296293
MulticastProtocol protocol = getMulticastProtocol();
@@ -303,7 +300,7 @@ SymmetricMemoryForReduce::SymmetricMemoryForReduce(
303300
at::Tensor semaphore = SymmetricTensor::allocate(
304301
/*sizes=*/at::IntArrayRef({1}),
305302
/*dtype=*/at::ScalarType::Int,
306-
/*device=*/output_buffer.device());
303+
/*device=*/buffer.device());
307304

308305
// Initialize the semaphore to kIdle
309306
IpcSemaphore init_value = IpcSemaphore::kIdle;

tests/cpp/test_multidevice_stream_parallel_type.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -764,6 +764,11 @@ TEST_P(RSMatmulTest, ReduceScatterReduceBased) {
764764
C->axis(1)->parallelize(ParallelType::DIDx);
765765
C->axis(0)->parallelize(ParallelType::Stream);
766766

767+
// Set the matmul's output to be symmetric in order to use NVLS multimem reduce.
768+
if (communicator_backend == CommunicatorBackend::kCuda) {
769+
C_unreduced->setMemoryType(MemoryType::Symmetric);
770+
}
771+
767772
MultiDeviceExecutorParams params;
768773
params.lower.communicator_backend = communicator_backend;
769774
params.lower.offset_stream_indexing_by_rank = false; // Will fail if true

0 commit comments

Comments (0)