2 changes: 0 additions & 2 deletions comms/torchcomms/ncclx/CMakeLists.txt
@@ -2,7 +2,6 @@
# Extension: torchcomms._comms_ncclx
file(GLOB TORCHCOMMS_NCCLX_SOURCES
"comms/torchcomms/ncclx/*.cpp"
"comms/torchcomms/transport/*.cc"
)
file(GLOB TORCHCOMMS_CUDA_API_SOURCE "comms/torchcomms/device/CudaApi.cpp")

@@ -61,7 +60,6 @@ add_library(torchcomms_comms_ncclx MODULE
${TORCHCOMMS_NCCLX_SOURCES}
${TORCHCOMMS_CUDA_API_SOURCE}
)
target_compile_definitions(torchcomms_comms_ncclx PRIVATE MOCK_SCUBA_DATA CTRAN_DISABLE_TCPDM)
set_target_properties(torchcomms_comms_ncclx PROPERTIES
PREFIX ""
OUTPUT_NAME "_comms_ncclx"
95 changes: 0 additions & 95 deletions comms/torchcomms/ncclx/TorchCommNCCLXPy.cpp
@@ -1,114 +1,19 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.

#include <folly/io/async/EventBase.h>
#include <folly/io/async/ScopedEventBaseThread.h>
#include <pybind11/chrono.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/csrc/utils/pybind.h>

#include "comms/torchcomms/ncclx/TorchCommNCCLX.hpp"
#include "comms/torchcomms/transport/RdmaTransport.h"

namespace py = pybind11;
using namespace torch::comms;

namespace {
folly::ScopedEventBaseThread& getScopedEventBaseThread() {
// This intentionally creates and leaks a global event base thread to be used
// for all Transports on first use.
static folly::ScopedEventBaseThread scopedEventBaseThread{"torchcomms_evb"};
return scopedEventBaseThread;
}
} // namespace

PYBIND11_MODULE(_comms_ncclx, m) {
m.doc() = "NCCLX specific python bindings for TorchComm";

py::class_<TorchCommNCCLX, std::shared_ptr<TorchCommNCCLX>>(
m, "TorchCommNCCLX");

py::class_<RdmaRemoteBuffer, std::shared_ptr<RdmaRemoteBuffer>>(
m, "RdmaRemoteBuffer")
.def(
py::pickle(
[](const RdmaRemoteBuffer& buffer) { // __getstate__
return py::make_tuple(
reinterpret_cast<uintptr_t>(buffer.ptr), buffer.accessKey);
},
[](const py::tuple& t) { // __setstate__
if (t.size() != 2) {
throw std::runtime_error(
"Invalid state for RdmaRemoteBuffer");
}
return RdmaRemoteBuffer{
reinterpret_cast<void*>(t[0].cast<uintptr_t>()),
t[1].cast<std::string>()};
}));

py::class_<RdmaTransport, std::shared_ptr<RdmaTransport>>(m, "RdmaTransport")
// initialize a new RDMATransport using a custom init fn
.def(py::init([](at::Device device) {
TORCH_INTERNAL_ASSERT(device.is_cuda());
int cuda_device = device.index();
return std::make_shared<RdmaTransport>(
cuda_device, getScopedEventBaseThread().getEventBase());
}))
.def_static("supported", &RdmaTransport::supported)
.def("bind", [](RdmaTransport& self) { return py::bytes(self.bind()); })
.def(
"connect",
[](RdmaTransport& self, const py::bytes& peerUrl) {
std::string peerUrlStr = peerUrl.cast<std::string>();
return static_cast<int>(self.connect(peerUrlStr));
})
.def("connected", &RdmaTransport::connected)
.def(
"write",
[](RdmaTransport& self,
const RdmaMemory::View& localBuffer,
const RdmaRemoteBuffer& remoteBuffer) {
return static_cast<int>(
self.write(localBuffer, remoteBuffer, false).get());
})
.def(
"read",
[](RdmaTransport& self,
RdmaMemory::MutableView& localBuffer,
const RdmaRemoteBuffer& remoteBuffer) {
return static_cast<int>(self.read(localBuffer, remoteBuffer).get());
});

py::class_<RdmaMemory::View, std::shared_ptr<RdmaMemory::View>>(
m, "RdmaMemoryView")
.def("size", &RdmaMemory::View::size);

py::class_<RdmaMemory::MutableView, std::shared_ptr<RdmaMemory::MutableView>>(
m, "RdmaMemoryMutableView");

py::class_<RdmaMemory, std::shared_ptr<RdmaMemory>>(m, "RdmaMemory")
.def(py::init([](const at::Tensor& tensor) {
TORCH_CHECK(
tensor.is_contiguous(),
"RdmaMemory currently requires a contiguous tensor");
// If CPU memory is passed, use device 0 for NIC discovery
const auto device = tensor.get_device() < 0 ? 0 : tensor.get_device();
return std::make_shared<RdmaMemory>(
tensor.data_ptr(), tensor.nbytes(), device);
}))
.def(
"to_view",
[](RdmaMemory& self) {
return self.createView(size_t(0), self.length());
})
.def(
"to_mutable_view",
[](RdmaMemory& self) {
return self.createMutableView(size_t(0), self.length());
})
.def("to_remote_buffer", [](RdmaMemory& self) {
return RdmaRemoteBuffer{
const_cast<void*>(self.data()), self.remoteKey()};
});
}
31 changes: 0 additions & 31 deletions comms/torchcomms/ncclx/_comms_ncclx.pyi
@@ -1,35 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# pyre-strict

import torch

class TorchCommNCCLX: ...

class RdmaMemoryView:
def size(self) -> int: ...

class RdmaMemoryMutableView: ...

class RdmaRemoteBuffer:
def __getstate__(self) -> tuple[int, str]: ...
def __setstate__(self, state: tuple[int, str]) -> None: ...

class RdmaMemory:
def __init__(self, tensor: torch.Tensor) -> None: ...
def to_view(self) -> RdmaMemoryView: ...
def to_mutable_view(self) -> RdmaMemoryMutableView: ...
def to_remote_buffer(self) -> RdmaRemoteBuffer: ...

class RdmaTransport:
def __init__(self, device: torch.device) -> None: ...
@staticmethod
def supported() -> bool: ...
def bind(self) -> bytes: ...
def connect(self, peer_url: bytes) -> int: ...
def connected(self) -> bool: ...
def write(
self, local_buffer: RdmaMemoryView, remote_buffer: RdmaRemoteBuffer
) -> int: ...
def read(
self, local_buffer: RdmaMemoryMutableView, remote_buffer: RdmaRemoteBuffer
) -> int: ...
110 changes: 110 additions & 0 deletions comms/torchcomms/transport/RdmaTransportPy.cpp
@@ -0,0 +1,110 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.

#include <folly/io/async/EventBase.h>
#include <folly/io/async/ScopedEventBaseThread.h>
#include <pybind11/chrono.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/csrc/utils/pybind.h>

#include "comms/torchcomms/transport/RdmaTransport.h"

using namespace torch::comms;

namespace {
folly::ScopedEventBaseThread& getScopedEventBaseThread() {
// This intentionally creates and leaks a global event base thread to be used
// for all Transports on first use.
static folly::ScopedEventBaseThread scopedEventBaseThread{"torchcomms_evb"};
return scopedEventBaseThread;
}
} // namespace

PYBIND11_MODULE(_transport, m) {
m.doc() = "RdmaTransport python bindings for TorchComm";

py::class_<RdmaRemoteBuffer, std::shared_ptr<RdmaRemoteBuffer>>(
m, "RdmaRemoteBuffer")
.def(
py::pickle(
[](const RdmaRemoteBuffer& buffer) { // __getstate__
return py::make_tuple(
reinterpret_cast<uintptr_t>(buffer.ptr), buffer.accessKey);
},
[](const py::tuple& t) { // __setstate__
if (t.size() != 2) {
throw std::runtime_error(
"Invalid state for RdmaRemoteBuffer");
}
return RdmaRemoteBuffer{
// NOLINTNEXTLINE(performance-no-int-to-ptr)
reinterpret_cast<void*>(t[0].cast<uintptr_t>()),
t[1].cast<std::string>()};
}));

py::class_<RdmaTransport, std::shared_ptr<RdmaTransport>>(m, "RdmaTransport")
// initialize a new RDMATransport using a custom init fn
.def(py::init([](at::Device device) {
TORCH_INTERNAL_ASSERT(device.is_cuda());
int cuda_device = device.index();
return std::make_shared<RdmaTransport>(
cuda_device, getScopedEventBaseThread().getEventBase());
}))
.def_static("supported", &RdmaTransport::supported)
.def("bind", [](RdmaTransport& self) { return py::bytes(self.bind()); })
.def(
"connect",
[](RdmaTransport& self, const py::bytes& peerUrl) {
std::string peerUrlStr = peerUrl.cast<std::string>();
return static_cast<int>(self.connect(peerUrlStr));
})
.def("connected", &RdmaTransport::connected)
.def(
"write",
[](RdmaTransport& self,
const RdmaMemory::View& localBuffer,
const RdmaRemoteBuffer& remoteBuffer) {
return static_cast<int>(
self.write(localBuffer, remoteBuffer, false).get());
})
.def(
"read",
[](RdmaTransport& self,
RdmaMemory::MutableView& localBuffer,
const RdmaRemoteBuffer& remoteBuffer) {
return static_cast<int>(self.read(localBuffer, remoteBuffer).get());
});

py::class_<RdmaMemory::View, std::shared_ptr<RdmaMemory::View>>(
m, "RdmaMemoryView")
.def("size", &RdmaMemory::View::size);

py::class_<RdmaMemory::MutableView, std::shared_ptr<RdmaMemory::MutableView>>(
m, "RdmaMemoryMutableView");

py::class_<RdmaMemory, std::shared_ptr<RdmaMemory>>(m, "RdmaMemory")
.def(py::init([](const at::Tensor& tensor) {
TORCH_CHECK(
tensor.is_contiguous(),
"RdmaMemory currently requires a contiguous tensor");
// If CPU memory is passed, use device 0 for NIC discovery
const auto device = tensor.get_device() < 0 ? 0 : tensor.get_device();
return std::make_shared<RdmaMemory>(
tensor.data_ptr(), tensor.nbytes(), device);
}))
.def(
"to_view",
[](RdmaMemory& self) {
return self.createView(size_t(0), self.length());
})
.def(
"to_mutable_view",
[](RdmaMemory& self) {
return self.createMutableView(size_t(0), self.length());
})
.def("to_remote_buffer", [](RdmaMemory& self) {
return RdmaRemoteBuffer{
const_cast<void*>(self.data()), self.remoteKey()};
});
}
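# Side note (not part of the diff): a minimal sketch of the pickling behaviour the
# py::pickle lambdas above define, assuming the new extension imports as
# torchcomms._transport and the host can register RDMA memory. RdmaRemoteBuffer
# round-trips through the (pointer, access key) tuple returned by __getstate__.
import pickle

import torch
from torchcomms._transport import RdmaMemory

tensor = torch.zeros(4096)                    # CPU memory is allowed; device 0 is used for NIC discovery
remote = RdmaMemory(tensor).to_remote_buffer()

state = remote.__getstate__()                 # (pointer as int, access key as str)
restored = pickle.loads(pickle.dumps(remote)) # rebuilt on the receiving side via __setstate__
assert restored.__getstate__() == state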
33 changes: 33 additions & 0 deletions comms/torchcomms/transport/_transport.pyi
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# pyre-strict

import torch

class RdmaMemoryView:
def size(self) -> int: ...

class RdmaMemoryMutableView: ...

class RdmaRemoteBuffer:
def __getstate__(self) -> tuple[int, str]: ...
def __setstate__(self, state: tuple[int, str]) -> None: ...

class RdmaMemory:
def __init__(self, tensor: torch.Tensor) -> None: ...
def to_view(self) -> RdmaMemoryView: ...
def to_mutable_view(self) -> RdmaMemoryMutableView: ...
def to_remote_buffer(self) -> RdmaRemoteBuffer: ...

class RdmaTransport:
def __init__(self, device: torch.device) -> None: ...
@staticmethod
def supported() -> bool: ...
def bind(self) -> bytes: ...
def connect(self, peer_url: bytes) -> int: ...
def connected(self) -> bool: ...
def write(
self, local_buffer: RdmaMemoryView, remote_buffer: RdmaRemoteBuffer
) -> int: ...
def read(
self, local_buffer: RdmaMemoryMutableView, remote_buffer: RdmaRemoteBuffer
) -> int: ...
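# Side note (not part of the diff): a minimal usage sketch of the stubs above,
# assuming the module builds as torchcomms._transport. The exchange() helper is
# hypothetical; it stands in for whatever out-of-band channel (store, socket,
# collective) the two peers use to swap bind URLs and remote buffers, and the
# "0 means success" return-code checks are assumptions, not documented behaviour.
import torch
from torchcomms._transport import RdmaMemory, RdmaTransport

def rdma_write_to_peer(exchange, device: torch.device) -> None:
    # The binding asserts a CUDA device; the NIC is selected near it.
    assert RdmaTransport.supported()
    transport = RdmaTransport(device)

    # bind() returns an opaque URL; each peer ships its URL to the other
    # through the hypothetical exchange() side channel, then connects.
    peer_url = exchange(transport.bind())
    assert transport.connect(peer_url) == 0 and transport.connected()

    send = torch.ones(1 << 20, dtype=torch.uint8, device=device)
    recv = torch.zeros(1 << 20, dtype=torch.uint8, device=device)

    send_mem = RdmaMemory(send)   # registers the contiguous tensor for RDMA
    recv_mem = RdmaMemory(recv)

    # Advertise the receive buffer; RdmaRemoteBuffer pickles as (ptr, key),
    # so it can travel over the same side channel.
    remote = exchange(recv_mem.to_remote_buffer())

    # One-sided write of the local view into the peer's buffer.
    assert transport.write(send_mem.to_view(), remote) == 0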
@@ -7,7 +7,7 @@
import unittest

import torch
from torchcomms._comms_ncclx import RdmaMemory, RdmaTransport
from torchcomms._transport import RdmaMemory, RdmaTransport


class TransportTest(unittest.TestCase):