Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 46 additions & 6 deletions csrc/multidevice/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
#include <numeric>

#ifdef NVFUSER_DISTRIBUTED
#include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
#include <torch/csrc/distributed/c10d/PrefixStore.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/exception.h>
#ifdef USE_C10D_NCCL
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
Expand Down Expand Up @@ -121,7 +123,8 @@ bool parseEnv(
}

// retrieves master port
if ((env = std::getenv("NVFUSER_MASTER_PORT")) != nullptr) {
env = std::getenv("NVFUSER_MASTER_PORT");
if (env != nullptr) {
master_port = std::atoi(env);
} else {
LOG(INFO) << "The environment variable NVFUSER_MASTER_PORT has not been "
Expand Down Expand Up @@ -248,10 +251,10 @@ void waitForDebuggerAtRanks(
std::cerr << "Process " << pid
<< " is waiting for the debugger. To continue debugging, "
<< "start gdb, `attach " << pid
<< "`, `set var waiting=false`, and `fini`." << std::endl;
<< "`, `set var waiting=false`, and `fini`.\n";
while (waiting) { // Please change `waiting` in the debugger.
}
std::cerr << "Process " << getpid() << " finished waiting." << std::endl;
std::cerr << "Process " << getpid() << " finished waiting.\n";
}

if (communicator->is_available()) {
Expand Down Expand Up @@ -349,19 +352,25 @@ void Communicator::cleanup() {

store_ = nullptr;

#if defined(NVFUSER_DISTRIBUTED) && defined(USE_C10D_NCCL)
#if defined(NVFUSER_DISTRIBUTED)
#if defined(USE_C10D_NCCL)
// Sort backends to work around an NCCL bug (nvbugs/4889623). Closing backends
// in different orders between ranks has been causing a hang.
std::vector<std::pair<std::string, c10::intrusive_ptr<c10d::Backend>>>
keyed_backends(backends_.begin(), backends_.end());
std::sort(keyed_backends.begin(), keyed_backends.end());
std::ranges::sort(keyed_backends.begin(), keyed_backends.end());
for (auto& [key, backend] : keyed_backends) {
// Call shutdown before destructing a ProcessGroupNCCL as instructed by
// https://github.com/pytorch/pytorch/blob/e62073d7997c9e63896cb5289ffd0874a8cc1838/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L1164-L1170.
if (auto* pg_nccl = dynamic_cast<c10d::ProcessGroupNCCL*>(backend.get())) {
pg_nccl->shutdown();
}
}
#endif
for (const auto& entry : process_groups_) {
c10d::unregister_process_group(entry.first);
}
process_groups_.clear();
Comment on lines +370 to +373
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0 process_groups_ cleanup guard mismatch — compile error when NVFUSER_DISTRIBUTED is set without USE_DISTRIBUTED

process_groups_ is declared in communicator.h under #if defined(NVFUSER_DISTRIBUTED) && defined(USE_DISTRIBUTED), but the cleanup loop here lives under the broader #if defined(NVFUSER_DISTRIBUTED) (without the USE_DISTRIBUTED guard). When a build defines NVFUSER_DISTRIBUTED but not USE_DISTRIBUTED, process_groups_ does not exist as a member, yet this code tries to iterate over it — a hard compile error.

c10d::unregister_process_group (from GroupRegistry.hpp) is already included under #ifdef NVFUSER_DISTRIBUTED, so fixing just the guard on these lines is sufficient:

Suggested change
for (const auto& entry : process_groups_) {
c10d::unregister_process_group(entry.first);
}
process_groups_.clear();
#if defined(USE_DISTRIBUTED)
for (const auto& entry : process_groups_) {
c10d::unregister_process_group(entry.first);
}
process_groups_.clear();
#endif

(The surrounding #if defined(NVFUSER_DISTRIBUTED) / #endif already provides the outer distributed guard.)

#endif
backends_.clear();
}
Expand All @@ -388,7 +397,7 @@ c10d::Backend* Communicator::getBackendForTeam(
#ifdef NVFUSER_DISTRIBUTED
backends_[team_key] = [&]() -> c10::intrusive_ptr<c10d::Backend> {
// check that the caller's rank belongs to the requested team
auto rank_it = std::find(team.begin(), team.end(), deviceId());
auto rank_it = std::ranges::find(team.begin(), team.end(), deviceId());
if (rank_it == team.end()) {
return nullptr;
}
Expand All @@ -402,6 +411,28 @@ c10d::Backend* Communicator::getBackendForTeam(
}();
#else
backends_[team_key] = nullptr;
#endif
#if defined(NVFUSER_DISTRIBUTED) && defined(USE_DISTRIBUTED)
std::optional<c10d::ProcessGroup::BackendType> pg_backend =
(b == CommunicatorBackend::kNccl)
? std::optional<c10d::ProcessGroup::BackendType>(
c10d::ProcessGroup::BackendType::NCCL)
: std::nullopt;
if (backends_[team_key] != nullptr && pg_backend.has_value()) {
auto rank_it = std::ranges::find(team.begin(), team.end(), deviceId());
RankType team_rank = std::distance(team.begin(), rank_it);

auto pg = c10::make_intrusive<c10d::ProcessGroup>(
c10::make_intrusive<c10d::PrefixStore>(team_key, store_),
team_rank,
static_cast<int>(team.size()));
pg->setBackend(c10::DeviceType::CUDA, *pg_backend, backends_[team_key]);
pg->setDefaultBackend(*pg_backend);
pg->setGroupName(team_key);

c10d::register_process_group(team_key, pg);
process_groups_[team_key] = std::move(pg);
}
Comment on lines +415 to +435
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 ProcessGroup wrapper only created on first backend creation — silently missing after getWorld()

The ProcessGroup wrapper is registered inside the backends_.find(team_key) == backends_.end() guard, meaning it is only created the first time getBackendForTeam is called for a given team_key. If getWorld() (or any early comm.barrier()) is called before the first PyTorch symmetric memory operation, the NCCL backend gets created and cached with no ProcessGroup wrapper. When getSymmMemGroupKey subsequently calls getBackendForTeam, it finds team_key already in backends_ and returns early — no ProcessGroup is created and c10d::resolve_process_group(group_name) inside ensurePyTorchSymmMemBackend will throw.

The ProcessGroup registration should not be gated solely on first-time backend creation. Consider also checking process_groups_.count(team_key) == 0 so the wrapper is created even when the backend already exists:

if (backends_[team_key] != nullptr && pg_backend.has_value()
    && process_groups_.count(team_key) == 0) {
  // ... create and register ProcessGroup
}

#endif
}
return backends_.at(team_key).get();
Expand All @@ -424,4 +455,13 @@ void Communicator::barrier(std::optional<CommunicatorBackend> backend) {
getWorld(backend)->barrier(options)->wait();
}

std::string Communicator::getSymmMemGroupKey(
std::optional<CommunicatorBackend> backend) {
std::vector<RankType> all_ranks(size_);
std::iota(all_ranks.begin(), all_ranks.end(), 0);
CommunicatorBackend b = backend.value_or(default_backend_);
(void)getBackendForTeam(all_ranks, b);
return getTeamKey(all_ranks, b);
}

} // namespace nvfuser
12 changes: 11 additions & 1 deletion csrc/multidevice/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
#include <ATen/core/ivalue.h>
#include <c10/util/intrusive_ptr.h>

#ifdef NVFUSER_DISTRIBUTED
#if defined(NVFUSER_DISTRIBUTED) && defined(USE_DISTRIBUTED)
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this header should always be present, no?

#include <torch/csrc/distributed/c10d/TCPStore.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#else
Expand Down Expand Up @@ -110,6 +111,10 @@ class NVF_API Communicator {
c10d::Backend* getWorld(
std::optional<CommunicatorBackend> backend = std::nullopt);

// Returns the world process-group name for the given backend.
std::string getSymmMemGroupKey(
std::optional<CommunicatorBackend> backend = std::nullopt);

// returns if a backend is available for creation
bool isBackendAvailable(CommunicatorBackend backend) const {
if (backend == CommunicatorBackend::kUcc) {
Expand Down Expand Up @@ -153,6 +158,11 @@ class NVF_API Communicator {
c10::intrusive_ptr<c10d::TCPStore> store_;
// cache for the created backends. The keys are strings generated from Teams
std::unordered_map<std::string, c10::intrusive_ptr<c10d::Backend>> backends_;
// c10d process-group wrappers registered for symmetric-memory rendezvous.
#if defined(NVFUSER_DISTRIBUTED) && defined(USE_DISTRIBUTED)
std::unordered_map<std::string, c10::intrusive_ptr<c10d::ProcessGroup>>
process_groups_;
#endif
};

} // namespace nvfuser
46 changes: 34 additions & 12 deletions csrc/multidevice/ipc_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ int createIpcSocket(const std::string& path) {
int sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
NVF_CHECK(sockfd >= 0, "Failed to create socket: ", strerror(errno));

struct sockaddr_un addr;
struct sockaddr_un addr{};
setupSockAddr(addr, path);

// For abstract namespace, len is usually calculated specifically, but for
Expand Down Expand Up @@ -69,31 +69,34 @@ void sendFd(
int sockfd = socket(AF_UNIX, SOCK_STREAM, 0);
NVF_CHECK(sockfd >= 0, "Failed to create socket: ", strerror(errno));

struct sockaddr_un addr;
struct sockaddr_un addr{};
setupSockAddr(addr, path);
socklen_t addrlen = sizeof(addr.sun_family) + path.length();

// Simple retry loop for connection
int ret = -1;
for (int i = 0; i < 100; ++i) {
ret = connect(sockfd, (struct sockaddr*)&addr, addrlen);
if (ret == 0)
if (ret == 0) {
break;
}
usleep(10000); // 10ms
}
if (ret < 0) {
close(sockfd);
NVF_CHECK(false, "Failed to connect to ", path, ": ", strerror(errno));
}

struct msghdr msg = {0};
struct cmsghdr* cmsg;
struct msghdr msg{};
struct cmsghdr* cmsg = nullptr;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays)
char buf[CMSG_SPACE(sizeof(int))];

// If no header data, send at least one byte
char dummy = '.';
struct iovec iov;
struct iovec iov{};
if (header_data && header_len > 0) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
iov.iov_base = const_cast<void*>(header_data);
iov.iov_len = header_len;
} else {
Expand Down Expand Up @@ -121,21 +124,22 @@ void sendFd(
}

int recvFd(int socket_fd, void* header_data, size_t header_len) {
struct sockaddr_un client_addr;
struct sockaddr_un client_addr{};
socklen_t client_len = sizeof(client_addr);
int client_fd =
accept(socket_fd, (struct sockaddr*)&client_addr, &client_len);
NVF_CHECK(client_fd >= 0, "Failed to accept connection: ", strerror(errno));

struct msghdr msg = {0};
struct cmsghdr* cmsg;
struct msghdr msg{};
struct cmsghdr* cmsg = nullptr;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays)
char buf[CMSG_SPACE(sizeof(int))];

// If header_len > 0, we expect that much data.
// Note: recvmsg might return fewer bytes if strict requirements aren't met,
// but for local unix sockets with small payloads, it usually delivers all.
char dummy;
struct iovec iov;
char dummy = '.';
struct iovec iov{};
if (header_data && header_len > 0) {
iov.iov_base = header_data;
iov.iov_len = header_len;
Expand Down Expand Up @@ -168,7 +172,7 @@ int recvFd(int socket_fd, void* header_data, size_t header_len) {

int recv_fd = -1;
cmsg = CMSG_FIRSTHDR(&msg);
if (cmsg != NULL && cmsg->cmsg_len == CMSG_LEN(sizeof(int))) {
if (cmsg != nullptr && cmsg->cmsg_len == CMSG_LEN(sizeof(int))) {
if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
memcpy(&recv_fd, CMSG_DATA(cmsg), sizeof(int));
}
Expand All @@ -191,4 +195,22 @@ MulticastProtocol getMulticastProtocol() {
return MulticastProtocol::BatchMemcpy;
}

// Resolves which symmetric-memory backend to use from the
// SymmetricMemoryBackend enable option. Falls back to the Fuser-native
// implementation when the option is disabled or carries no recognized
// argument. Arguments are checked in a fixed priority order: pytorch_nccl,
// then pytorch_nvshmem, then pytorch_cuda.
SymmetricMemoryBackend getSymmetricMemoryBackend() {
  if (!isOptionEnabled(EnableOption::SymmetricMemoryBackend)) {
    return SymmetricMemoryBackend::Native;
  }

  // Small helper so each probe below stays on one line.
  const auto has_arg = [](const char* arg) {
    return hasEnableOptionArgument(EnableOption::SymmetricMemoryBackend, arg);
  };

  if (has_arg("pytorch_nccl")) {
    return SymmetricMemoryBackend::PyTorchNccl;
  }
  if (has_arg("pytorch_nvshmem")) {
    return SymmetricMemoryBackend::PyTorchNvshmem;
  }
  if (has_arg("pytorch_cuda")) {
    return SymmetricMemoryBackend::PyTorchCuda;
  }
  return SymmetricMemoryBackend::Native;
}

} // namespace nvfuser
16 changes: 15 additions & 1 deletion csrc/multidevice/ipc_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,24 @@ const T& fromBytes(const std::vector<uint8_t>& bytes) {

// IPC Utils for sharing file descriptors

enum class MulticastProtocol { Memcpy, Multimem, BatchMemcpy };
enum class MulticastProtocol : uint8_t { Memcpy, Multimem, BatchMemcpy };

MulticastProtocol getMulticastProtocol();

// Backend for symmetric memory allocation and rendezvous.
// Native: Fuser's own CUDA VMM + IPC implementation (default, maintained).
// PyTorch*: Use PyTorch's symmetric memory
// (torch.distributed._symmetric_memory) with the given transport backend (Nccl,
// Nvshmem, or Cuda).
enum class SymmetricMemoryBackend : uint8_t {
Native, // Fuser-native CUDA VMM + IPC path (the default).
PyTorchNccl, // PyTorch symmetric memory over the NCCL transport.
PyTorchNvshmem, // PyTorch symmetric memory over the NVSHMEM transport.
PyTorchCuda, // PyTorch symmetric memory over the CUDA transport.
};

// Reads the SymmetricMemoryBackend enable option and returns the selected
// backend; returns Native when the option is unset or no recognized
// argument is given.
SymmetricMemoryBackend getSymmetricMemoryBackend();

// Creates a listening Unix domain socket bound to path.
// If path starts with '@', it uses the abstract namespace (replaced with \0).
// Returns the socket file descriptor.
Expand Down
Loading
Loading