Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions csrc/host_ir/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,12 +431,16 @@ void HostIrEvaluator::handle(MoeCombine* combine) {
auto n_tokens_from_rank =
getKnownConcreteValue(combine->inTokensFromRank()).as<at::Tensor>();

auto num_tokens =
expr_evaluator_.evaluate(combine->numTokens()).as<int64_t>();

auto result = doMoeCombine(
x,
topk_weights,
src_idx,
n_tokens_to_rank,
n_tokens_from_rank,
num_tokens,
communicator_,
combine->backend());

Expand Down
2 changes: 2 additions & 0 deletions csrc/multidevice/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,15 @@ MoeCombine::MoeCombine(
TensorView* in_src_idx,
TensorView* in_n_tokens_to_rank,
TensorView* in_n_tokens_from_rank,
Val* num_tokens,
CommunicatorBackend backend)
: Expr(passkey) {
addInput(in_x);
addInput(in_topk_weights);
addInput(in_src_idx);
addInput(in_n_tokens_to_rank);
addInput(in_n_tokens_from_rank);
addInput(num_tokens);
addOutput(out_x);
addDataAttribute(backend);
validate();
Expand Down
12 changes: 12 additions & 0 deletions csrc/multidevice/communication.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ class MoeCombine : public Expr {
TensorView* in_src_idx,
TensorView* in_n_tokens_to_rank,
TensorView* in_n_tokens_from_rank,
Val* num_tokens,
CommunicatorBackend backend = CommunicatorBackend::kNccl);

MoeCombine(const MoeCombine& other) = delete;
Expand Down Expand Up @@ -317,6 +318,17 @@ class MoeCombine : public Expr {
return input(4)->as<TensorView>();
}

//! Number of tokens local to this rank before dispatch (T), i.e. the
//! extent of the dispatch input's first axis.
//!
//! The CUDA backend uses this to size its recv buffers and the output
//! without a GPU-to-CPU sync. Once pre-allocated outputs are supported,
//! the combine output TensorView's shape could be used directly instead.
//!
//! Held as input 5, the input the constructor registers right after
//! in_n_tokens_from_rank.
Val* numTokens() const {
  Val* const token_count = input(5);
  return token_count;
}

//! Communicator backend this op was constructed with, stored as data
//! attribute 0.
CommunicatorBackend backend() const {
  const CommunicatorBackend selected = attribute<CommunicatorBackend>(0);
  return selected;
}
Expand Down
22 changes: 5 additions & 17 deletions csrc/multidevice/cuda_p2p.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -970,25 +970,10 @@ void alltoallvWithCudaBackend(
const at::Tensor& send,
const at::Tensor& recv,
const AlltoallvMetadata& metadata,
const std::vector<void*>& recv_ptrs,
const at::Tensor& recv_ptrs_gpu,
CUstream stream) {
NVF_CHECK(send.is_cuda(), "alltoallv send must be CUDA.");
NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA.");
NVF_CHECK(
(int64_t)recv_ptrs.size() == metadata.world_size,
"recv_ptrs size must match world size.");

auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options);
auto* ptrs = recv_ptrs_cpu.data_ptr<int64_t>();
for (int64_t rank = 0; rank < metadata.world_size; ++rank) {
ptrs[rank] =
static_cast<int64_t>(reinterpret_cast<uintptr_t>(recv_ptrs[rank]));
}
auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device());

const int64_t elem_stride =
metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1;
NVF_CHECK(
metadata.max_send_total == 0 ||
send.numel() % metadata.max_send_total == 0,
Comment on lines 1463 to 1466
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Missing validation for recv_ptrs_gpu — no size or device check

The old signature took `const std::vector<void*>& recv_ptrs`, and the function body explicitly verified its size:

NVF_CHECK(
    (int64_t)recv_ptrs.size() == metadata.world_size,
    "recv_ptrs size must match world size.");

It also built the pointer table on the CPU and copied it to the send device via `.to(send.device())`, so the table was guaranteed to live on the same device as `send`.

The new `at::Tensor recv_ptrs_gpu` parameter has neither check. If it holds fewer than `world_size` entries, the kernel silently reads garbage pointers; if it lives on the wrong device, the launch will fault. `remotePointersTensor()` always produces a `[world_size]` tensor on the right device by construction, but the API contract is now implicit and fragile for any future caller. Consider adding explicit checks:

Suggested change
NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA.");
NVF_CHECK(
(int64_t)recv_ptrs.size() == metadata.world_size,
"recv_ptrs size must match world size.");
auto cpu_options = at::TensorOptions().dtype(at::kLong).device(at::kCPU);
auto recv_ptrs_cpu = at::empty({metadata.world_size}, cpu_options);
auto* ptrs = recv_ptrs_cpu.data_ptr<int64_t>();
for (int64_t rank = 0; rank < metadata.world_size; ++rank) {
ptrs[rank] =
static_cast<int64_t>(reinterpret_cast<uintptr_t>(recv_ptrs[rank]));
}
auto recv_ptrs_cuda = recv_ptrs_cpu.to(send.device());
const int64_t elem_stride =
metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1;
NVF_CHECK(
metadata.max_send_total == 0 ||
send.numel() % metadata.max_send_total == 0,
NVF_CHECK(recv.is_cuda(), "alltoallv recv must be CUDA.");
NVF_CHECK(
recv_ptrs_gpu.is_cuda() && recv_ptrs_gpu.device() == send.device(),
"recv_ptrs_gpu must be a CUDA tensor on the same device as send.");
NVF_CHECK(
recv_ptrs_gpu.dim() == 1 &&
recv_ptrs_gpu.size(0) == metadata.world_size,
"recv_ptrs_gpu must have shape [world_size].");

Expand All @@ -997,6 +982,9 @@ void alltoallvWithCudaBackend(
metadata.max_recv == 0 || recv.numel() % metadata.max_recv == 0,
"alltoallv recv numel must be divisible by max_recv.");

const int64_t elem_stride =
metadata.max_send_total > 0 ? send.numel() / metadata.max_send_total : 1;

auto send_offsets = metadata.send_offsets;
auto send_counts = metadata.send_counts;
auto recv_offsets = metadata.recv_offsets;
Expand All @@ -1010,7 +998,7 @@ void alltoallvWithCudaBackend(

launchAlltoallvKernel(
send.data_ptr(),
reinterpret_cast<const uint64_t*>(recv_ptrs_cuda.data_ptr<int64_t>()),
reinterpret_cast<const uint64_t*>(recv_ptrs_gpu.data_ptr<int64_t>()),
send_offsets.data_ptr<int64_t>(),
send_counts.data_ptr<int64_t>(),
recv_offsets.data_ptr<int64_t>(),
Expand Down
15 changes: 10 additions & 5 deletions csrc/multidevice/cuda_p2p.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,14 @@ struct AlltoallvMetadata {
at::Tensor recv_counts; // CUDA [R]
at::Tensor send_offsets; // CUDA [R]
at::Tensor recv_offsets; // CUDA [R]
int64_t total_recv = 0;
int64_t max_recv = 0;
int64_t max_send_total = 0;
int64_t max_send_bytes = 0;

// CPU scalars — upper bounds from the caller, NOT read from GPU.
// Using upper bounds (instead of exact GPU values) avoids CPU-GPU
// sync and keeps the data path CUDA-graph-capturable.
int64_t total_recv = 0; // upper bound on sum(recv_counts)
int64_t max_recv = 0; // recv buffer first dim
int64_t max_send_total = 0; // send buffer first dim = sum(send_counts)
int64_t max_send_bytes = 0; // max per-peer send count (kernel grid X)
int64_t world_size = 0;
};

Expand All @@ -64,7 +68,8 @@ void alltoallvWithCudaBackend(
const at::Tensor& send,
const at::Tensor& recv,
const AlltoallvMetadata& metadata,
const std::vector<void*>& recv_ptrs,
const at::Tensor& recv_ptrs_gpu, // CUDA [R] int64, from
// SymmetricTensor::remotePointersTensor
CUstream stream);

void alltoallvBarrier(const std::string& tag);
Expand Down
Loading
Loading