Remove extra copy by assuming reduce tensor is symmetric #6040
nsarka wants to merge 2 commits into NVIDIA:main
Conversation
Greptile Summary: This PR removes an unnecessary device-to-device copy (input → symmetric buffer) on the Reduce path by assuming the reduce input tensor is already allocated in symmetric memory.
Confidence Score: 5/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant R0 as Rank 0 (root)
    participant Rn as Rank N (non-root)
    participant SM as Symmetric Memory (multicast)
    Note over R0,SM: Input tensors are already in symmetric memory (new assumption)
    R0->>SM: Write kInProgress to own semaphore
    Rn->>SM: Write kInProgress to own semaphore
    R0->>SM: Wait for all non-roots → kInProgress
    SM-->>R0: All non-roots ready
    R0->>SM: launchMulticastReduceKernel(multicastPtr → output)
    Note over R0: Reads from multicast ptr (all ranks' sym buffers)<br/>Writes result to output tensor (regular memory)
    R0->>SM: Signal kIdle to all non-root semaphores
    SM-->>Rn: kIdle received
    Note over R0,Rn: Old path had an extra cudaMemcpyAsync<br/>(input → sym buffer) before all of this
```
!test
```cpp
NVF_CHECK(
    input.scalar_type() == at::kFloat &&
        reduce_handle->inputBuffer().scalar_type() == at::kFloat &&
        (!output.defined() || output.scalar_type() == at::kFloat),
```
For better error reporting, can you split this into two NVF_CHECKs? The first one may use NVF_CHECK_EQ.
```cpp
// For Broadcast and Reduce, non-roots may have no output; use input for
// cache key
at::Tensor cache_buffer =
    output_tensor.defined() ? output_tensor : input_tensor;
```
I'm trying to figure out the contract here since the code is getting a bit too if-elsy.
- When can input_tensor/output_tensor be undefined? When and only when the rank is not in the team?
- For the kCuda backend, broadcast's output has to be symmetric and reduce's input has to be symmetric?
(I'm trying to confirm my understanding before proposing any cleanups -- thanks!)
Yes, for the kCuda backend. I will update the PR shortly to do the same for allreduce.