
Remove extra copy by assuming reduce tensor is symmetric#6040

Open
nsarka wants to merge 2 commits into NVIDIA:main from nsarka:nsarka/nvls-buffer

Conversation

@nsarka
Member

@nsarka nsarka commented Mar 25, 2026

For the CUDA backend. I will update the PR shortly to do the same for allreduce.

@nsarka nsarka requested review from samnordmann and wujingyue March 25, 2026 21:42
@nsarka nsarka self-assigned this Mar 25, 2026
@greptile-apps
Contributor

greptile-apps bot commented Mar 25, 2026

Greptile Summary

This PR removes an unnecessary device-to-device copy in the CUDA backend's Reduce collective by requiring the input tensor to already reside in symmetric (VMM-backed) memory. This lets the multimem ld_reduce kernel read directly from the multicast VA without first staging data into a separate symmetric buffer. The caller-visible change is that, when using the CUDA backend, the reduce-input TensorView must now be annotated with MemoryType::Symmetric.

Key changes:

  • SymmetricMemoryForReduce now wraps the caller-supplied buffer directly instead of allocating a new symmetric tensor and scheduling a cudaMemcpyAsync copy into it.
  • postReduceWithCudaBackend drops the input parameter and the copy, tightens the root assertion to require output to be defined, and reads scalar type from the handle's buffer.
  • The handle(Communication*) post path in the evaluator correctly switches the cache key to input_tensor for Reduce so that the symmetric handle is keyed by the symmetric input.
  • However, the matching handle(Wait* wait) path was not updated. On the root rank, the wait handler resolves cache_buffer to output_tensor (which is defined for root), while the post handler used input_tensor. SymmetricMemoryHandleCache::get() will not find the existing entry and will attempt to construct a new SymmetricMemoryForReduce from the non-symmetric output tensor, which immediately fails inside SymmetricTensor::validate(). The root rank will crash at wait time for every Reduce with the CUDA backend.

Confidence Score: 4/5

Not safe to merge as-is: the root rank will crash at the Wait phase due to a missing cache-key fix in the wait handler.

The core optimization in cuda_p2p.cpp and ipc_handle.cpp is correct and clean. The evaluator post path is also correctly updated. However, the wait handler in evaluator.cpp was not updated to mirror the post path's Reduce-specific cache-key logic, causing the root rank to look up a different cache entry during Wait and triggering a runtime NVF_CHECK failure. Score 4 reflects one clear P1 defect that must be resolved before merging.

csrc/host_ir/evaluator.cpp — the Wait handler (around line 527) needs the same if (Reduce) { cache_buffer = input_tensor; } guard that was added to the post path.

Important Files Changed

Filename Overview
csrc/host_ir/evaluator.cpp Post path correctly updated to use input_tensor as cache key for Reduce, but the wait handler is missing the same fix — the root rank will crash at wait time when the cache tries to construct a SymmetricMemoryForReduce from a non-symmetric output tensor.
csrc/multidevice/cuda_p2p.cpp Removes the cudaMemcpyAsync copy from input to symmetric buffer and the now-unnecessary input parameter from postReduceWithCudaBackend; also tightens root assertion to require output to be defined — all changes are correct and consistent.
csrc/multidevice/ipc_handle.cpp SymmetricMemoryForReduce constructor now wraps the caller-supplied (already symmetric) buffer directly instead of allocating a new one and copying into it; SymmetricTensor::validate() enforces the symmetric-memory requirement at construction time.
tests/cpp/test_multidevice_stream_parallel_type.cpp Adds setMemoryType(MemoryType::Symmetric) on C_unreduced for the CUDA backend to satisfy the new requirement that the reduce input be a symmetric tensor; non-CUDA path is unchanged.
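Based on the file summary above, the caller-side change in the test presumably looks roughly like this. It is a sketch, not the actual diff: the surrounding test code is not shown on this page, and the spelling of the backend check is assumed.

```cpp
// Hypothetical sketch: for the CUDA backend, the reduce input must live in
// symmetric (VMM-backed) memory so the multimem ld_reduce kernel can read it
// directly from the multicast VA.
if (communicator_backend == CommunicatorBackend::kCuda) {
  C_unreduced->setMemoryType(MemoryType::Symmetric);
}
```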

Sequence Diagram

sequenceDiagram
    participant E as HostIrEvaluator
    participant C as HandleCache
    participant R as SymMemForReduce
    participant K as CUDA Reduce Kernel

    note over E: Post phase (handle Communication)
    E->>C: get(input_tensor, comm, root)
    C->>R: construct with input_tensor (symmetric)
    R-->>C: handle created
    C-->>E: reduce_handle ptr

    note over K: Before this PR: cudaMemcpyAsync input to sym_buf (removed)
    E->>K: postReduceWithCudaBackend(output, handle, stream, root)
    K->>K: All ranks signal kInProgress
    K->>K: Root waits for all non-roots
    K->>K: launchMulticastReduceKernel via mc_ptr to output
    K->>K: Root signals kIdle to non-roots

    note over E: Wait phase (handle Wait) - BUG on root rank
    E->>C: get(output_tensor, comm, root) - mismatched key on root
    C->>R: construct with output_tensor (NOT symmetric)
    R-->>C: SymmetricTensor validate FAILS - crash

Comments Outside Diff (1)

  1. csrc/host_ir/evaluator.cpp, lines 527-530

    Wait handler missing Reduce cache-key fix

    The handle(Wait* wait) method still uses the old cache-key logic for Reduce operations. For the root rank, output_tensor is defined, so cache_buffer is set to output_tensor here. But in handle(Communication*), the post path now always uses input_tensor as the cache key for Reduce (lines 385-387).

    The post and wait paths therefore look up two different cache entries for the same communication on the root rank. SymmetricMemoryHandleCache::get() does not find the entry created during post and tries to create a new SymmetricMemoryForReduce using output_tensor as the buffer. The SymmetricTensor constructor calls SymmetricTensor::validate(), which will fail with "Invalid symmetric allocation" because the output is a plain CUDA tensor, not a VMM-backed symmetric allocation. The root rank will crash at the wait phase for every Reduce with the CUDA backend. Non-root ranks are unaffected (they have no output tensor, so both paths resolve to input_tensor).

    The same fix applied in the post path needs to be mirrored here:

Reviews (2): last reviewed commit "Lint"

@nsarka
Member Author

nsarka commented Mar 25, 2026

!test

Comment on lines 876 to 878

NVF_CHECK(
    input.scalar_type() == at::kFloat &&
        reduce_handle->inputBuffer().scalar_type() == at::kFloat &&
        (!output.defined() || output.scalar_type() == at::kFloat),
Collaborator


For better error reporting, can you split this into two NVF_CHECKs? The first one may use NVF_CHECK_EQ.
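One possible split along the lines suggested here might look like the following. This is a sketch against the hunk quoted above, not code from the PR; the error-message string is a placeholder, since the original message is cut off by the hunk boundary:

```cpp
// Per-condition checks report which operand had the wrong dtype.
NVF_CHECK_EQ(input.scalar_type(), at::kFloat);
NVF_CHECK_EQ(reduce_handle->inputBuffer().scalar_type(), at::kFloat);
NVF_CHECK(
    !output.defined() || output.scalar_type() == at::kFloat,
    "Expected the reduce output, when defined, to be a float tensor");
```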

// For Broadcast and Reduce, non-roots may have no output; use input for
// cache key
at::Tensor cache_buffer =
    output_tensor.defined() ? output_tensor : input_tensor;
Collaborator


I'm trying to figure out the contract here since the code is getting a bit too if-elsy.

  1. When can input_tensor/output_tensor be undefined? When and only when the rank is not in the team?
  2. For the kCuda backend, broadcast's output has to be symmetric and reduce's input has to be symmetric?

(I'm trying to confirm my understanding before proposing any cleanups -- thanks!)

@nsarka nsarka force-pushed the nsarka/nvls-buffer branch from b57ce84 to e8d7f80 on March 30, 2026 at 13:01