Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/host_ir/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ void HostIrEvaluator::handle(MoeCombine* combine) {
src_idx,
n_tokens_to_rank,
n_tokens_from_rank,
combine->numTokens(),
communicator_,
combine->backend());

Expand Down
2 changes: 2 additions & 0 deletions csrc/multidevice/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ MoeCombine::MoeCombine(
TensorView* in_src_idx,
TensorView* in_n_tokens_to_rank,
TensorView* in_n_tokens_from_rank,
int64_t num_tokens,
CommunicatorBackend backend)
: Expr(passkey) {
addInput(in_x);
Expand All @@ -304,6 +305,7 @@ MoeCombine::MoeCombine(
addInput(in_n_tokens_to_rank);
addInput(in_n_tokens_from_rank);
addOutput(out_x);
addDataAttribute(num_tokens);
addDataAttribute(backend);
validate();
}
Expand Down
10 changes: 9 additions & 1 deletion csrc/multidevice/communication.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ class MoeCombine : public Expr {
TensorView* in_src_idx,
TensorView* in_n_tokens_to_rank,
TensorView* in_n_tokens_from_rank,
int64_t num_tokens,
CommunicatorBackend backend = CommunicatorBackend::kNccl);

MoeCombine(const MoeCombine& other) = delete;
Expand Down Expand Up @@ -317,8 +318,15 @@ class MoeCombine : public Expr {
return input(4)->as<TensorView>();
}

//! Number of local tokens that existed before dispatch, i.e. x.size(0)
//! of the dispatch call. The CUDA backend uses this value to size the
//! combine output, so no GPU-to-CPU sync is needed to learn the size.
//! Stored as data attribute 0 (registered first in the constructor).
int64_t numTokens() const {
  const int64_t local_token_count = attribute<int64_t>(0);
  return local_token_count;
}

//! Communicator backend used for this combine.
//! The constructor registers num_tokens as data attribute 0 and the
//! backend as data attribute 1, so the backend must be read from index 1;
//! reading index 0 would reinterpret the token count as a backend.
CommunicatorBackend backend() const {
  return attribute<CommunicatorBackend>(1);
}

private:
Expand Down
24 changes: 3 additions & 21 deletions csrc/multidevice/cuda_p2p.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -970,32 +970,13 @@ void alltoallvWithCudaBackend(
const at::Tensor& send,
const at::Tensor& recv,
const AlltoallvMetadata& metadata,
const std::vector<void*>& recv_ptrs,
const at::Tensor& recv_ptrs_gpu,
CUstream stream) {
NVF_CHECK(send.is_cuda(), "alltoallv send must be CUDA.");
NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA.");
NVF_CHECK(
(int64_t)recv_ptrs.size() == metadata.world_size,
"recv_ptrs size must match world size.");

auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options);
auto* ptrs = recv_ptrs_cpu.data_ptr<int64_t>();
for (int64_t rank = 0; rank < metadata.world_size; ++rank) {
ptrs[rank] =
static_cast<int64_t>(reinterpret_cast<uintptr_t>(recv_ptrs[rank]));
}
auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device());

const int64_t elem_stride =
metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1;
Comment on lines 1472 to 1473
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Divisibility guard removed — silently wrong elem_stride if metadata is mismatched

The PR removes the checks:

NVF_CHECK(
    metadata.max_send_total == 0 ||
        send.numel() % metadata.max_send_total == 0, ...);
NVF_CHECK(
    metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0, ...);

elem_stride is computed as send.numel() / metadata.max_send_total. If send.numel() is not divisible by max_send_total (e.g. because a caller passes mismatched metadata), integer truncation silently gives a wrong stride. send_offsets, send_counts, and recv_offsets are then all scaled by this wrong value before being passed to the kernel, producing corrupted data without any error. The checks were cheap and provided essential diagnostic value; removing them for graph-capturability gains nothing, because they run on the CPU and never execute inside a captured region.

NVF_CHECK(
metadata.max_send_total == 0 ||
send.numel() % metadata.max_send_total == 0,
"alltoallv send numel must be divisible by max_send_total.");
NVF_CHECK(
metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0,
"alltoallv recv numel must be divisible by max_recv.");

auto send_offsets = metadata.send_offsets;
auto send_counts = metadata.send_counts;
Expand All @@ -1010,7 +991,8 @@ void alltoallvWithCudaBackend(

launchAlltoallvKernel(
send.data_ptr(),
reinterpret_cast<const uint64_t*>(recv_ptrs_cuda.data_ptr<int64_t>()),
reinterpret_cast<const uint64_t*>(
recv_ptrs_gpu.data_ptr<int64_t>()),
send_offsets.data_ptr<int64_t>(),
send_counts.data_ptr<int64_t>(),
recv_offsets.data_ptr<int64_t>(),
Expand Down
7 changes: 6 additions & 1 deletion csrc/multidevice/cuda_p2p.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ struct AlltoallvMetadata {
at::Tensor recv_counts; // CUDA [R]
at::Tensor send_offsets; // CUDA [R]
at::Tensor recv_offsets; // CUDA [R]

// CPU scalars — upper bounds from the caller, NOT read from GPU.
// Using upper bounds (instead of exact GPU values) avoids CPU-GPU
// sync and keeps the data path CUDA-graph-capturable.
int64_t total_recv = 0;
int64_t max_recv = 0;
int64_t max_send_total = 0;
Expand All @@ -64,7 +68,8 @@ void alltoallvWithCudaBackend(
const at::Tensor& send,
const at::Tensor& recv,
const AlltoallvMetadata& metadata,
const std::vector<void*>& recv_ptrs,
const at::Tensor& recv_ptrs_gpu, // CUDA [R] int64, from
// SymmetricTensor::remotePointersTensor
CUstream stream);

void alltoallvBarrier(const std::string& tag);
Expand Down
Loading
Loading