
Remove extra copy by assuming reduce tensor is symmetric #6040

Open
nsarka wants to merge 2 commits into NVIDIA:main from nsarka:nsarka/nvls-buffer

Conversation

@nsarka
Member

@nsarka nsarka commented Mar 25, 2026

This change is for the CUDA backend. I will update the PR shortly to do the same for allreduce.

@nsarka nsarka requested review from samnordmann and wujingyue March 25, 2026 21:42
@nsarka nsarka self-assigned this Mar 25, 2026
@greptile-apps
Contributor

greptile-apps bot commented Mar 25, 2026

Greptile Summary

This PR removes an unnecessary cudaMemcpyAsync from the CUDA-backend reduce path by assuming the input tensor is already allocated in symmetric (NVLS) memory, eliminating a device-to-device copy on every reduce call.

Key changes:

  • ipc_handle.cpp: SymmetricMemoryForReduce no longer allocates a separate symmetric tensor; it wraps the caller-supplied buffer directly via SymmetricTensor (which validates symmetric alignment at construction time).
  • cuda_p2p.cpp: The staging cudaMemcpyAsync and the input parameter to postReduceWithCudaBackend are removed. The old silent fallback of writing reduce output into the input buffer on root (when no output was defined) is replaced with an explicit NVF_ERROR.
  • evaluator.cpp: The symmetric-memory cache key for Reduce is now unconditionally the input_tensor (reflecting that the input, not the output, is the symmetric buffer).
  • test_multidevice_stream_parallel_type.cpp: The matmul output C_unreduced is marked MemoryType::Symmetric when using the CUDA backend, satisfying the new requirement.
  • The header ipc_handle.h still declares the third constructor parameter as output_buffer, which is now a stale and misleading name — the buffer is the symmetric input tensor.

Confidence Score: 5/5

  • Safe to merge; the optimization is sound and runtime validation in SymmetricTensor::validate() guards the new assumption.
  • The change is a targeted performance optimization that removes one device-to-device copy per reduce operation. The invariant (input must be symmetric) is enforced at handle-construction time by SymmetricTensor::validate(), the test is correctly updated, and the only remaining issue is a stale parameter name in the header declaration — a non-functional cleanup.
  • csrc/multidevice/ipc_handle.h — constructor declaration still uses the old parameter name output_buffer.

Important Files Changed

  • csrc/host_ir/evaluator.cpp: The cache key for Reduce operations now unconditionally uses input_tensor, since the input (not the output) is the symmetric buffer. The logic is correct; the two-step assignment could be simplified but is not incorrect.
  • csrc/multidevice/cuda_p2p.cpp: Removes the redundant cudaMemcpyAsync that staged the input into the symmetric buffer, and turns an undefined output on the root into an explicit error. The changes are correct and safe given the new symmetric-input assumption.
  • csrc/multidevice/ipc_handle.cpp: The constructor now wraps the caller-provided tensor directly as the symmetric input buffer instead of allocating a new one. SymmetricTensor::validate() enforces correctness at runtime. The header declaration still uses the stale name output_buffer.
  • tests/cpp/test_multidevice_stream_parallel_type.cpp: Correctly gates setMemoryType(MemoryType::Symmetric) behind the kCuda backend check so the NCCL path is unaffected.

Sequence Diagram

sequenceDiagram
    participant R0 as Rank 0 (root)
    participant Rn as Rank N (non-root)
    participant SM as Symmetric Memory (multicast)

    Note over R0,SM: Input tensors are already in symmetric memory (new assumption)

    R0->>SM: Write kInProgress to own semaphore
    Rn->>SM: Write kInProgress to own semaphore

    R0->>SM: Wait for all non-roots → kInProgress
    SM-->>R0: All non-roots ready

    R0->>SM: launchMulticastReduceKernel(multicastPtr → output)
    Note over R0: Reads from multicast ptr (all ranks' sym buffers)<br/>Writes result to output tensor (regular memory)

    R0->>SM: Signal kIdle to all non-root semaphores
    SM-->>Rn: kIdle received

    Note over R0,Rn: Old path had an extra cudaMemcpyAsync<br/>(input → sym buffer) before all of this

Comments Outside Diff (1)

  1. csrc/multidevice/ipc_handle.h, line 252

    P2 Stale parameter name in header declaration

    The constructor declaration in the header still uses the old parameter name output_buffer, but the implementation in ipc_handle.cpp now uses buffer. More importantly, output_buffer is now semantically misleading — the tensor passed here is the input (symmetric) buffer, not an output buffer.

Reviews (1): Last reviewed commit: "Lint"

@nsarka
Member Author

nsarka commented Mar 25, 2026

!test

Comment on lines 876 to 878
NVF_CHECK(
    input.scalar_type() == at::kFloat &&
        reduce_handle->inputBuffer().scalar_type() == at::kFloat &&
        (!output.defined() || output.scalar_type() == at::kFloat),
Collaborator


For better error reporting, can you split this into two NVF_CHECKs? The first one may use NVF_CHECK_EQ.

// For Broadcast and Reduce, non-roots may have no output; use input for
// cache key
at::Tensor cache_buffer =
    output_tensor.defined() ? output_tensor : input_tensor;
Collaborator


I'm trying to figure out the contract here since the code is getting a bit too if-elsy.

  1. When can input_tensor/output_tensor be undefined? When and only when the rank is not in the team?
  2. For the kCuda backend, broadcast's output has to be symmetric and reduce's input has to be symmetric?

(I'm trying to confirm my understanding before proposing any cleanups -- thanks!)


2 participants