Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion csrc/host_ir/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,15 @@ void HostIrEvaluator::handle(Communication* communication) {
communication->type() == CommunicationType::Reduce)
? root_val
: -1;
// For Reduce, non-roots may have no output; use input for cache key
// For Broadcast and Reduce, non-roots may have no output; use input for
// cache key
at::Tensor cache_buffer =
output_tensor.defined() ? output_tensor : input_tensor;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying to figure out the contract here since the code is getting a bit too if-elsy.

  1. When can input_tensor/output_tensor be undefined? When and only when the rank is not in the team?
  2. For the kCuda backend, broadcast's output has to be symmetric and reduce's input has to be symmetric?

(I'm trying to confirm my understanding before proposing any cleanups -- thanks!)

// For Reduce specifically, use the input tensor for the cache key
// since the output doesn't need to be symmetric
if (communication->type() == CommunicationType::Reduce) {
cache_buffer = input_tensor;
}
SymmetricMemoryHandle* multicast_handle = multicast_handle_cache_.get(
{.buffer = cache_buffer, .expr = communication, .root = cache_root});
postWithCudaBackend(
Expand Down
19 changes: 5 additions & 14 deletions csrc/multidevice/cuda_p2p.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,6 @@ void waitAllreduceWithCudaBackend(

void postReduceWithCudaBackend(
Communication* communication,
at::Tensor input,
at::Tensor output,
SymmetricMemoryForReduce* reduce_handle,
CUstream stream,
Expand All @@ -875,7 +874,7 @@ void postReduceWithCudaBackend(
"Only SUM reduction is supported for multimem reduce; got ",
communication->reduceOp());
NVF_CHECK(
input.scalar_type() == at::kFloat &&
reduce_handle->inputBuffer().scalar_type() == at::kFloat &&
(!output.defined() || output.scalar_type() == at::kFloat),
Comment on lines 876 to 878
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For better error reporting, can you split this into two NVF_CHECKs? The first one may use NVF_CHECK_EQ.

"Only float32 is supported for multimem reduce.");

Expand All @@ -886,14 +885,6 @@ void postReduceWithCudaBackend(
"size=",
size);

// Copy input to symmetric buffer
NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpyAsync(
reduce_handle->inputBuffer().data_ptr(),
input.data_ptr(),
size,
cudaMemcpyDeviceToDevice,
stream));

const int64_t world_size = communicator.size();
// All ranks signal ready by writing kInProgress to their own semaphore
NVFUSER_CUDA_SAFE_CALL(cuStreamWriteValue32(
Expand Down Expand Up @@ -923,10 +914,10 @@ void postReduceWithCudaBackend(
cuStreamBatchMemOp(stream, world_size - 1, ops.data(), 0));

// Root launches the ld_reduce kernel
void* dst = output.defined() ? output.data_ptr()
: reduce_handle->inputBuffer().data_ptr();
NVF_ERROR(
output.defined(), "Output must be defined for reduce on the root");
launchMulticastReduceKernel(
reduce_handle->multicastPtr(), dst, size, stream);
reduce_handle->multicastPtr(), output.data_ptr(), size, stream);

// Root signals completion by writing kIdle to all non-root semaphores
std::vector<CUstreamBatchMemOpParams> write_complete_ops(world_size - 1);
Expand Down Expand Up @@ -1267,7 +1258,7 @@ void postWithCudaBackend(
dynamic_cast<SymmetricMemoryForReduce*>(symmetric_memory_handle);
NVF_ERROR(reduce_handle != nullptr, "Invalid reduce handle");
postReduceWithCudaBackend(
communication, input, output, reduce_handle, stream, root);
communication, output, reduce_handle, stream, root);
break;
}
default:
Expand Down
13 changes: 5 additions & 8 deletions csrc/multidevice/ipc_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,17 +280,14 @@ void* SymmetricMemoryForAllreduce::semaphoreUnicastPtr(
SymmetricMemoryForReduce::SymmetricMemoryForReduce(
Communication* communication,
int64_t root,
at::Tensor output_buffer)
: size_bytes_(output_buffer.numel() * output_buffer.element_size()) {
at::Tensor buffer)
: size_bytes_(buffer.numel() * buffer.element_size()) {
std::string name_suffix =
"for_Communication" + std::to_string(communication->name());
std::string store_key_prefix = "nvls_reduce_" + name_suffix;

at::Tensor input_sym = SymmetricTensor::allocate(
output_buffer.sizes(),
output_buffer.scalar_type(),
output_buffer.device());
input_sym_tensor_ = std::make_unique<SymmetricTensor>(input_sym);
// We assume the input buffer is already a symmetric tensor
input_sym_tensor_ = std::make_unique<SymmetricTensor>(buffer);
input_sym_tensor_->setupRemoteHandles(store_key_prefix + "_input_unicast");

MulticastProtocol protocol = getMulticastProtocol();
Expand All @@ -303,7 +300,7 @@ SymmetricMemoryForReduce::SymmetricMemoryForReduce(
at::Tensor semaphore = SymmetricTensor::allocate(
/*sizes=*/at::IntArrayRef({1}),
/*dtype=*/at::ScalarType::Int,
/*device=*/output_buffer.device());
/*device=*/buffer.device());

// Initialize the semaphore to kIdle
IpcSemaphore init_value = IpcSemaphore::kIdle;
Expand Down
6 changes: 6 additions & 0 deletions tests/cpp/test_multidevice_stream_parallel_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,12 @@ TEST_P(RSMatmulTest, ReduceScatterReduceBased) {
C->axis(1)->parallelize(ParallelType::DIDx);
C->axis(0)->parallelize(ParallelType::Stream);

// Set the matmul's output to be symmetric in order to use NVLS multimem
// reduce.
if (communicator_backend == CommunicatorBackend::kCuda) {
C_unreduced->setMemoryType(MemoryType::Symmetric);
}

MultiDeviceExecutorParams params;
params.lower.communicator_backend = communicator_backend;
params.lower.offset_stream_indexing_by_rank = false; // Will fail if true
Expand Down
Loading