diff --git a/comms/torchcomms/ncclx/CMakeLists.txt b/comms/torchcomms/ncclx/CMakeLists.txt
index 069f3ba8..7950de19 100644
--- a/comms/torchcomms/ncclx/CMakeLists.txt
+++ b/comms/torchcomms/ncclx/CMakeLists.txt
@@ -2,7 +2,6 @@
 # Extension: torchcomms._comms_ncclx
 file(GLOB TORCHCOMMS_NCCLX_SOURCES
   "comms/torchcomms/ncclx/*.cpp"
-  "comms/torchcomms/transport/*.cc"
 )
 file(GLOB TORCHCOMMS_CUDA_API_SOURCE "comms/torchcomms/device/CudaApi.cpp")
@@ -61,7 +60,6 @@ add_library(torchcomms_comms_ncclx MODULE
   ${TORCHCOMMS_NCCLX_SOURCES}
   ${TORCHCOMMS_CUDA_API_SOURCE}
 )
-target_compile_definitions(torchcomms_comms_ncclx PRIVATE MOCK_SCUBA_DATA CTRAN_DISABLE_TCPDM)
 set_target_properties(torchcomms_comms_ncclx PROPERTIES
   PREFIX ""
   OUTPUT_NAME "_comms_ncclx"
diff --git a/comms/torchcomms/ncclx/TorchCommNCCLXPy.cpp b/comms/torchcomms/ncclx/TorchCommNCCLXPy.cpp
index c6308fc1..a06c04e6 100644
--- a/comms/torchcomms/ncclx/TorchCommNCCLXPy.cpp
+++ b/comms/torchcomms/ncclx/TorchCommNCCLXPy.cpp
@@ -1,7 +1,5 @@
 // Copyright (c) Meta Platforms, Inc. and affiliates.
 
-#include <folly/io/async/EventBase.h>
-#include <folly/io/async/ScopedEventBaseThread.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 #include <torch/csrc/utils/pybind.h>
@@ -9,106 +7,13 @@
 #include <memory>
 
 #include "comms/torchcomms/ncclx/TorchCommNCCLX.hpp"
-#include "comms/torchcomms/transport/RdmaTransport.h"
 
 namespace py = pybind11;
 using namespace torch::comms;
 
-namespace {
-folly::ScopedEventBaseThread& getScopedEventBaseThread() {
-  // This intentionally creates and leaks a global event base thread to be used
-  // for all Transports on first use.
-  static folly::ScopedEventBaseThread scopedEventBaseThread{"torchcomms_evb"};
-  return scopedEventBaseThread;
-}
-} // namespace
-
 PYBIND11_MODULE(_comms_ncclx, m) {
   m.doc() = "NCCLX specific python bindings for TorchComm";
 
   py::class_<TorchCommNCCLX, std::shared_ptr<TorchCommNCCLX>>(
       m, "TorchCommNCCLX");
-
-  py::class_<RdmaRemoteBuffer, std::shared_ptr<RdmaRemoteBuffer>>(
-      m, "RdmaRemoteBuffer")
-      .def(
-          py::pickle(
-              [](const RdmaRemoteBuffer& buffer) { // __getstate__
-                return py::make_tuple(
-                    reinterpret_cast<uintptr_t>(buffer.ptr), buffer.accessKey);
-              },
-              [](const py::tuple& t) { // __setstate__
-                if (t.size() != 2) {
-                  throw std::runtime_error(
-                      "Invalid state for RdmaRemoteBuffer");
-                }
-                return RdmaRemoteBuffer{
-                    reinterpret_cast<void*>(t[0].cast<uintptr_t>()),
-                    t[1].cast<std::string>()};
-              }));
-
-  py::class_<RdmaTransport, std::shared_ptr<RdmaTransport>>(m, "RdmaTransport")
-      // initialize a new RdmaTransport using a custom init fn
-      .def(py::init([](at::Device device) {
-        TORCH_INTERNAL_ASSERT(device.is_cuda());
-        int cuda_device = device.index();
-        return std::make_shared<RdmaTransport>(
-            cuda_device, getScopedEventBaseThread().getEventBase());
-      }))
-      .def_static("supported", &RdmaTransport::supported)
-      .def("bind", [](RdmaTransport& self) { return py::bytes(self.bind()); })
-      .def(
-          "connect",
-          [](RdmaTransport& self, const py::bytes& peerUrl) {
-            std::string peerUrlStr = peerUrl.cast<std::string>();
-            return static_cast<int>(self.connect(peerUrlStr));
-          })
-      .def("connected", &RdmaTransport::connected)
-      .def(
-          "write",
-          [](RdmaTransport& self,
-             const RdmaMemory::View& localBuffer,
-             const RdmaRemoteBuffer& remoteBuffer) {
-            return static_cast<int>(
-                self.write(localBuffer, remoteBuffer, false).get());
-          })
-      .def(
-          "read",
-          [](RdmaTransport& self,
-             RdmaMemory::MutableView& localBuffer,
-             const RdmaRemoteBuffer& remoteBuffer) {
-            return static_cast<int>(self.read(localBuffer, remoteBuffer).get());
-          });
-
-  py::class_<RdmaMemory::View, std::shared_ptr<RdmaMemory::View>>(
-      m, "RdmaMemoryView")
-      .def("size", &RdmaMemory::View::size);
-
-  py::class_<RdmaMemory::MutableView, std::shared_ptr<RdmaMemory::MutableView>>(
-      m, "RdmaMemoryMutableView");
-
-  py::class_<RdmaMemory, std::shared_ptr<RdmaMemory>>(m, "RdmaMemory")
-      .def(py::init([](const at::Tensor& tensor) {
-        TORCH_CHECK(
-            tensor.is_contiguous(),
-            "RdmaMemory currently requires a contiguous tensor");
-        // If CPU memory is passed, use device 0 for NIC discovery
-        const auto device = tensor.get_device() < 0 ? 0 : tensor.get_device();
-        return std::make_shared<RdmaMemory>(
-            tensor.data_ptr(), tensor.nbytes(), device);
-      }))
-      .def(
-          "to_view",
-          [](RdmaMemory& self) {
-            return self.createView(size_t(0), self.length());
-          })
-      .def(
-          "to_mutable_view",
-          [](RdmaMemory& self) {
-            return self.createMutableView(size_t(0), self.length());
-          })
-      .def("to_remote_buffer", [](RdmaMemory& self) {
-        return RdmaRemoteBuffer{
-            const_cast<void*>(self.data()), self.remoteKey()};
-      });
 }
diff --git a/comms/torchcomms/ncclx/_comms_ncclx.pyi b/comms/torchcomms/ncclx/_comms_ncclx.pyi
index f77396f9..34e007e1 100644
--- a/comms/torchcomms/ncclx/_comms_ncclx.pyi
+++ b/comms/torchcomms/ncclx/_comms_ncclx.pyi
@@ -1,35 +1,4 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # pyre-strict
 
-import torch
-
 class TorchCommNCCLX: ...
-
-class RdmaMemoryView:
-    def size(self) -> int: ...
-
-class RdmaMemoryMutableView: ...
-
-class RdmaRemoteBuffer:
-    def __getstate__(self) -> tuple[int, str]: ...
-    def __setstate__(self, state: tuple[int, str]) -> None: ...
-
-class RdmaMemory:
-    def __init__(self, tensor: torch.Tensor) -> None: ...
-    def to_view(self) -> RdmaMemoryView: ...
-    def to_mutable_view(self) -> RdmaMemoryMutableView: ...
-    def to_remote_buffer(self) -> RdmaRemoteBuffer: ...
-
-class RdmaTransport:
-    def __init__(self, device: torch.device) -> None: ...
-    @staticmethod
-    def supported() -> bool: ...
-    def bind(self) -> bytes: ...
-    def connect(self, peer_url: bytes) -> int: ...
-    def connected(self) -> bool: ...
-    def write(
-        self, local_buffer: RdmaMemoryView, remote_buffer: RdmaRemoteBuffer
-    ) -> int: ...
-    def read(
-        self, local_buffer: RdmaMemoryMutableView, remote_buffer: RdmaRemoteBuffer
-    ) -> int: ...
diff --git a/comms/torchcomms/transport/RdmaTransport.cc b/comms/torchcomms/transport/RdmaTransport.cpp
similarity index 100%
rename from comms/torchcomms/transport/RdmaTransport.cc
rename to comms/torchcomms/transport/RdmaTransport.cpp
diff --git a/comms/torchcomms/transport/RdmaTransportPy.cpp b/comms/torchcomms/transport/RdmaTransportPy.cpp
new file mode 100644
index 00000000..287ef145
--- /dev/null
+++ b/comms/torchcomms/transport/RdmaTransportPy.cpp
@@ -0,0 +1,110 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+
+#include <ATen/core/Tensor.h>
+#include <c10/util/Exception.h>
+#include <folly/io/async/EventBase.h>
+#include <folly/io/async/ScopedEventBaseThread.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <torch/csrc/utils/pybind.h>
+
+#include "comms/torchcomms/transport/RdmaTransport.h"
+
+namespace py = pybind11;
+using namespace torch::comms;
+
+namespace {
+folly::ScopedEventBaseThread& getScopedEventBaseThread() {
+  // This intentionally creates and leaks a global event base thread to be used
+  // for all Transports on first use.
+  static folly::ScopedEventBaseThread scopedEventBaseThread{"torchcomms_evb"};
+  return scopedEventBaseThread;
+}
+} // namespace
+
+PYBIND11_MODULE(_transport, m) {
+  m.doc() = "RdmaTransport python bindings for TorchComm";
+
+  py::class_<RdmaRemoteBuffer, std::shared_ptr<RdmaRemoteBuffer>>(
+      m, "RdmaRemoteBuffer")
+      .def(
+          py::pickle(
+              [](const RdmaRemoteBuffer& buffer) { // __getstate__
+                return py::make_tuple(
+                    reinterpret_cast<uintptr_t>(buffer.ptr), buffer.accessKey);
+              },
+              [](const py::tuple& t) { // __setstate__
+                if (t.size() != 2) {
+                  throw std::runtime_error(
+                      "Invalid state for RdmaRemoteBuffer");
+                }
+                return RdmaRemoteBuffer{
+                    // NOLINTNEXTLINE(performance-no-int-to-ptr)
+                    reinterpret_cast<void*>(t[0].cast<uintptr_t>()),
+                    t[1].cast<std::string>()};
+              }));
+
+  py::class_<RdmaTransport, std::shared_ptr<RdmaTransport>>(m, "RdmaTransport")
+      // initialize a new RdmaTransport using a custom init fn
+      .def(py::init([](at::Device device) {
+        TORCH_INTERNAL_ASSERT(device.is_cuda());
+        int cuda_device = device.index();
+        return std::make_shared<RdmaTransport>(
+            cuda_device, getScopedEventBaseThread().getEventBase());
+      }))
+      .def_static("supported", &RdmaTransport::supported)
+      .def("bind", [](RdmaTransport& self) { return py::bytes(self.bind()); })
+      .def(
+          "connect",
+          [](RdmaTransport& self, const py::bytes& peerUrl) {
+            std::string peerUrlStr = peerUrl.cast<std::string>();
+            return static_cast<int>(self.connect(peerUrlStr));
+          })
+      .def("connected", &RdmaTransport::connected)
+      .def(
+          "write",
+          [](RdmaTransport& self,
+             const RdmaMemory::View& localBuffer,
+             const RdmaRemoteBuffer& remoteBuffer) {
+            return static_cast<int>(
+                self.write(localBuffer, remoteBuffer, false).get());
+          })
+      .def(
+          "read",
+          [](RdmaTransport& self,
+             RdmaMemory::MutableView& localBuffer,
+             const RdmaRemoteBuffer& remoteBuffer) {
+            return static_cast<int>(self.read(localBuffer, remoteBuffer).get());
+          });
+
+  py::class_<RdmaMemory::View, std::shared_ptr<RdmaMemory::View>>(
+      m, "RdmaMemoryView")
+      .def("size", &RdmaMemory::View::size);
+
+  py::class_<RdmaMemory::MutableView, std::shared_ptr<RdmaMemory::MutableView>>(
+      m, "RdmaMemoryMutableView");
+
+  py::class_<RdmaMemory, std::shared_ptr<RdmaMemory>>(m, "RdmaMemory")
+      .def(py::init([](const at::Tensor& tensor) {
+        TORCH_CHECK(
+            tensor.is_contiguous(),
+            "RdmaMemory currently requires a contiguous tensor");
+        // If CPU memory is passed, use device 0 for NIC discovery
+        const auto device = tensor.get_device() < 0 ? 0 : tensor.get_device();
+        return std::make_shared<RdmaMemory>(
+            tensor.data_ptr(), tensor.nbytes(), device);
+      }))
+      .def(
+          "to_view",
+          [](RdmaMemory& self) {
+            return self.createView(size_t(0), self.length());
+          })
+      .def(
+          "to_mutable_view",
+          [](RdmaMemory& self) {
+            return self.createMutableView(size_t(0), self.length());
+          })
+      .def("to_remote_buffer", [](RdmaMemory& self) {
+        return RdmaRemoteBuffer{
+            const_cast<void*>(self.data()), self.remoteKey()};
+      });
+}
diff --git a/comms/torchcomms/transport/_transport.pyi b/comms/torchcomms/transport/_transport.pyi
new file mode 100644
index 00000000..254522cd
--- /dev/null
+++ b/comms/torchcomms/transport/_transport.pyi
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# pyre-strict
+
+import torch
+
+class RdmaMemoryView:
+    def size(self) -> int: ...
+
+class RdmaMemoryMutableView: ...
+
+class RdmaRemoteBuffer:
+    def __getstate__(self) -> tuple[int, str]: ...
+    def __setstate__(self, state: tuple[int, str]) -> None: ...
+
+class RdmaMemory:
+    def __init__(self, tensor: torch.Tensor) -> None: ...
+    def to_view(self) -> RdmaMemoryView: ...
+    def to_mutable_view(self) -> RdmaMemoryMutableView: ...
+    def to_remote_buffer(self) -> RdmaRemoteBuffer: ...
+
+class RdmaTransport:
+    def __init__(self, device: torch.device) -> None: ...
+    @staticmethod
+    def supported() -> bool: ...
+    def bind(self) -> bytes: ...
+    def connect(self, peer_url: bytes) -> int: ...
+    def connected(self) -> bool: ...
+    def write(
+        self, local_buffer: RdmaMemoryView, remote_buffer: RdmaRemoteBuffer
+    ) -> int: ...
+    def read(
+        self, local_buffer: RdmaMemoryMutableView, remote_buffer: RdmaRemoteBuffer
+    ) -> int: ...
diff --git a/comms/torchcomms/transport/tests/RdmaTransportSupportXPlatTest.cc b/comms/torchcomms/transport/tests/cpp/RdmaTransportSupportXPlatTest.cc
similarity index 100%
rename from comms/torchcomms/transport/tests/RdmaTransportSupportXPlatTest.cc
rename to comms/torchcomms/transport/tests/cpp/RdmaTransportSupportXPlatTest.cc
diff --git a/comms/torchcomms/transport/tests/RdmaTransportTest.cc b/comms/torchcomms/transport/tests/cpp/RdmaTransportTest.cc
similarity index 100%
rename from comms/torchcomms/transport/tests/RdmaTransportTest.cc
rename to comms/torchcomms/transport/tests/cpp/RdmaTransportTest.cc
diff --git a/comms/torchcomms/tests/integration/py/TransportTest.py b/comms/torchcomms/transport/tests/py/TransportTest.py
similarity index 99%
rename from comms/torchcomms/tests/integration/py/TransportTest.py
rename to comms/torchcomms/transport/tests/py/TransportTest.py
index 2ad714f1..e2184bb4 100644
--- a/comms/torchcomms/tests/integration/py/TransportTest.py
+++ b/comms/torchcomms/transport/tests/py/TransportTest.py
@@ -7,7 +7,7 @@
 import unittest
 
 import torch
 
-from torchcomms._comms_ncclx import RdmaMemory, RdmaTransport
+from torchcomms._transport import RdmaMemory, RdmaTransport
 
 class TransportTest(unittest.TestCase):
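For context, here is a minimal sketch of how the relocated bindings fit together after this change, modeled on the `_transport.pyi` stub and TransportTest.py above. The `exchange` helper is hypothetical, standing in for whatever out-of-band channel (a socket, a rendezvous store, etc.) carries values between the two peer processes, and treating a return value of 0 from `connect`/`write` as success is an assumption, not something this diff specifies:

```python
# Hypothetical two-peer flow for torchcomms._transport (a sketch, not part of
# this diff). `exchange` is a stand-in for an out-of-band channel that swaps a
# value with the peer process; it is not provided by the module.
import torch
from torchcomms._transport import RdmaMemory, RdmaTransport

assert RdmaTransport.supported()

# Each peer creates a transport on its CUDA device and binds to get a URL.
transport = RdmaTransport(torch.device("cuda", 0))
peer_url: bytes = exchange(transport.bind())  # hypothetical helper

# Connect to the peer; 0 is assumed to mean success here.
assert transport.connect(peer_url) == 0
assert transport.connected()

# Register a contiguous tensor and publish its remote buffer to the peer.
# RdmaRemoteBuffer pickles via __getstate__/__setstate__, so it can travel
# over any serializing channel.
local = torch.ones(1024, device="cuda:0")
memory = RdmaMemory(local)
remote_buffer = exchange(memory.to_remote_buffer())

# One-sided RDMA write from the local read-only view into the peer's buffer.
assert transport.write(memory.to_view(), remote_buffer) == 0
```

Note that the bindings keep all transports on one shared, intentionally leaked `folly` event base thread (`getScopedEventBaseThread` above), so no per-transport thread management is needed from Python.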