Commit 8ef1259

tanquermeta-codesync[bot] authored and committed
Move pybind out of ncclx (#68)
Summary: Pull Request resolved: #68

D85694262 added the transport pybind layer under ncclx, which is not appropriate: they are very different libs, and it pulled a third-party library into ncclx and broke the internal conda image. Move the bindings to the RdmaTransport lib folder.

Reviewed By: d4l3k

Differential Revision: D87873651

fbshipit-source-id: fe9885bd6247d3048cd2c0a25907e6b967237ac9
1 parent 5ef73f5 commit 8ef1259
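
For downstream Python code, the visible effect is a new import path: the RDMA transport bindings now live in a standalone torchcomms._transport extension rather than torchcomms._comms_ncclx, as the TransportTest change below shows. A minimal before/after sketch (module and class names taken from this diff; nothing else is implied):

# Before this commit, the transport types were bundled into the NCCLX extension:
# from torchcomms._comms_ncclx import RdmaMemory, RdmaTransport

# After this commit, they come from the standalone transport extension:
from torchcomms._transport import RdmaMemory, RdmaTransport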

File tree

8 files changed: +144 −129 lines


comms/torchcomms/ncclx/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
@@ -2,7 +2,6 @@
 # Extension: torchcomms._comms_ncclx
 file(GLOB TORCHCOMMS_NCCLX_SOURCES
   "comms/torchcomms/ncclx/*.cpp"
-  "comms/torchcomms/transport/*.cpp"
 )
 file(GLOB TORCHCOMMS_CUDA_API_SOURCE "comms/torchcomms/device/CudaApi.cpp")
 
@@ -61,7 +60,6 @@ add_library(torchcomms_comms_ncclx MODULE
   ${TORCHCOMMS_NCCLX_SOURCES}
   ${TORCHCOMMS_CUDA_API_SOURCE}
 )
-target_compile_definitions(torchcomms_comms_ncclx PRIVATE MOCK_SCUBA_DATA CTRAN_DISABLE_TCPDM)
 set_target_properties(torchcomms_comms_ncclx PROPERTIES
   PREFIX ""
   OUTPUT_NAME "_comms_ncclx"
Lines changed: 0 additions & 95 deletions
@@ -1,114 +1,19 @@
 // Copyright (c) Meta Platforms, Inc. and affiliates.
 
-#include <folly/io/async/EventBase.h>
-#include <folly/io/async/ScopedEventBaseThread.h>
 #include <pybind11/chrono.h>
 #include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 #include <torch/csrc/utils/pybind.h>
 
 #include "comms/torchcomms/ncclx/TorchCommNCCLX.hpp"
-#include "comms/torchcomms/transport/RdmaTransport.h"
 
 namespace py = pybind11;
 using namespace torch::comms;
 
-namespace {
-folly::ScopedEventBaseThread& getScopedEventBaseThread() {
-  // This intentionally creates and leaks a global event base thread to be used
-  // for all Transports on first use.
-  static folly::ScopedEventBaseThread scopedEventBaseThread{"torchcomms_evb"};
-  return scopedEventBaseThread;
-}
-} // namespace
-
 PYBIND11_MODULE(_comms_ncclx, m) {
   m.doc() = "NCCLX specific python bindings for TorchComm";
 
   py::class_<TorchCommNCCLX, std::shared_ptr<TorchCommNCCLX>>(
       m, "TorchCommNCCLX");
-
-  py::class_<RdmaRemoteBuffer, std::shared_ptr<RdmaRemoteBuffer>>(
-      m, "RdmaRemoteBuffer")
-      .def(
-          py::pickle(
-              [](const RdmaRemoteBuffer& buffer) { // __getstate__
-                return py::make_tuple(
-                    reinterpret_cast<uintptr_t>(buffer.ptr), buffer.accessKey);
-              },
-              [](const py::tuple& t) { // __setstate__
-                if (t.size() != 2) {
-                  throw std::runtime_error(
-                      "Invalid state for RdmaRemoteBuffer");
-                }
-                return RdmaRemoteBuffer{
-                    reinterpret_cast<void*>(t[0].cast<uintptr_t>()),
-                    t[1].cast<std::string>()};
-              }));
-
-  py::class_<RdmaTransport, std::shared_ptr<RdmaTransport>>(m, "RdmaTransport")
-      // initialize a new RDMATransport using a custom init fn
-      .def(py::init([](at::Device device) {
-        TORCH_INTERNAL_ASSERT(device.is_cuda());
-        int cuda_device = device.index();
-        return std::make_shared<RdmaTransport>(
-            cuda_device, getScopedEventBaseThread().getEventBase());
-      }))
-      .def_static("supported", &RdmaTransport::supported)
-      .def("bind", [](RdmaTransport& self) { return py::bytes(self.bind()); })
-      .def(
-          "connect",
-          [](RdmaTransport& self, const py::bytes& peerUrl) {
-            std::string peerUrlStr = peerUrl.cast<std::string>();
-            return static_cast<int>(self.connect(peerUrlStr));
-          })
-      .def("connected", &RdmaTransport::connected)
-      .def(
-          "write",
-          [](RdmaTransport& self,
-             const RdmaMemory::View& localBuffer,
-             const RdmaRemoteBuffer& remoteBuffer) {
-            return static_cast<int>(
-                self.write(localBuffer, remoteBuffer, false).get());
-          })
-      .def(
-          "read",
-          [](RdmaTransport& self,
-             RdmaMemory::MutableView& localBuffer,
-             const RdmaRemoteBuffer& remoteBuffer) {
-            return static_cast<int>(self.read(localBuffer, remoteBuffer).get());
-          });
-
-  py::class_<RdmaMemory::View, std::shared_ptr<RdmaMemory::View>>(
-      m, "RdmaMemoryView")
-      .def("size", &RdmaMemory::View::size);
-
-  py::class_<RdmaMemory::MutableView, std::shared_ptr<RdmaMemory::MutableView>>(
-      m, "RdmaMemoryMutableView");
-
-  py::class_<RdmaMemory, std::shared_ptr<RdmaMemory>>(m, "RdmaMemory")
-      .def(py::init([](const at::Tensor& tensor) {
-        TORCH_CHECK(
-            tensor.is_contiguous(),
-            "RdmaMemory currently requires a contiguous tensor");
-        // If CPU memory is passed, use device 0 for NIC discovery
-        const auto device = tensor.get_device() < 0 ? 0 : tensor.get_device();
-        return std::make_shared<RdmaMemory>(
-            tensor.data_ptr(), tensor.nbytes(), device);
-      }))
-      .def(
-          "to_view",
-          [](RdmaMemory& self) {
-            return self.createView(size_t(0), self.length());
-          })
-      .def(
-          "to_mutable_view",
-          [](RdmaMemory& self) {
-            return self.createMutableView(size_t(0), self.length());
-          })
-      .def("to_remote_buffer", [](RdmaMemory& self) {
-        return RdmaRemoteBuffer{
-            const_cast<void*>(self.data()), self.remoteKey()};
-      });
 }
Lines changed: 0 additions & 31 deletions
@@ -1,35 +1,4 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # pyre-strict
 
-import torch
-
 class TorchCommNCCLX: ...
-
-class RdmaMemoryView:
-    def size(self) -> int: ...
-
-class RdmaMemoryMutableView: ...
-
-class RdmaRemoteBuffer:
-    def __getstate__(self) -> tuple[int, str]: ...
-    def __setstate__(self, state: tuple[int, str]) -> None: ...
-
-class RdmaMemory:
-    def __init__(self, tensor: torch.Tensor) -> None: ...
-    def to_view(self) -> RdmaMemoryView: ...
-    def to_mutable_view(self) -> RdmaMemoryMutableView: ...
-    def to_remote_buffer(self) -> RdmaRemoteBuffer: ...
-
-class RdmaTransport:
-    def __init__(self, device: torch.device) -> None: ...
-    @staticmethod
-    def supported() -> bool: ...
-    def bind(self) -> bytes: ...
-    def connect(self, peer_url: bytes) -> int: ...
-    def connected(self) -> bool: ...
-    def write(
-        self, local_buffer: RdmaMemoryView, remote_buffer: RdmaRemoteBuffer
-    ) -> int: ...
-    def read(
-        self, local_buffer: RdmaMemoryMutableView, remote_buffer: RdmaRemoteBuffer
-    ) -> int: ...
Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+
+#include <folly/io/async/EventBase.h>
+#include <folly/io/async/ScopedEventBaseThread.h>
+#include <pybind11/chrono.h>
+#include <pybind11/numpy.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <torch/csrc/utils/pybind.h>
+
+#include "comms/torchcomms/transport/RdmaTransport.h"
+
+using namespace torch::comms;
+
+namespace {
+folly::ScopedEventBaseThread& getScopedEventBaseThread() {
+  // This intentionally creates and leaks a global event base thread to be used
+  // for all Transports on first use.
+  static folly::ScopedEventBaseThread scopedEventBaseThread{"torchcomms_evb"};
+  return scopedEventBaseThread;
+}
+} // namespace
+
+PYBIND11_MODULE(_transport, m) {
+  m.doc() = "RdmaTransport python bindings for TorchComm";
+
+  py::class_<RdmaRemoteBuffer, std::shared_ptr<RdmaRemoteBuffer>>(
+      m, "RdmaRemoteBuffer")
+      .def(
+          py::pickle(
+              [](const RdmaRemoteBuffer& buffer) { // __getstate__
+                return py::make_tuple(
+                    reinterpret_cast<uintptr_t>(buffer.ptr), buffer.accessKey);
+              },
+              [](const py::tuple& t) { // __setstate__
+                if (t.size() != 2) {
+                  throw std::runtime_error(
+                      "Invalid state for RdmaRemoteBuffer");
+                }
+                return RdmaRemoteBuffer{
+                    // NOLINTNEXTLINE(performance-no-int-to-ptr)
+                    reinterpret_cast<void*>(t[0].cast<uintptr_t>()),
+                    t[1].cast<std::string>()};
+              }));
+
+  py::class_<RdmaTransport, std::shared_ptr<RdmaTransport>>(m, "RdmaTransport")
+      // initialize a new RDMATransport using a custom init fn
+      .def(py::init([](at::Device device) {
+        TORCH_INTERNAL_ASSERT(device.is_cuda());
+        int cuda_device = device.index();
+        return std::make_shared<RdmaTransport>(
+            cuda_device, getScopedEventBaseThread().getEventBase());
+      }))
+      .def_static("supported", &RdmaTransport::supported)
+      .def("bind", [](RdmaTransport& self) { return py::bytes(self.bind()); })
+      .def(
+          "connect",
+          [](RdmaTransport& self, const py::bytes& peerUrl) {
+            std::string peerUrlStr = peerUrl.cast<std::string>();
+            return static_cast<int>(self.connect(peerUrlStr));
+          })
+      .def("connected", &RdmaTransport::connected)
+      .def(
+          "write",
+          [](RdmaTransport& self,
+             const RdmaMemory::View& localBuffer,
+             const RdmaRemoteBuffer& remoteBuffer) {
+            return static_cast<int>(
+                self.write(localBuffer, remoteBuffer, false).get());
+          })
+      .def(
+          "read",
+          [](RdmaTransport& self,
+             RdmaMemory::MutableView& localBuffer,
+             const RdmaRemoteBuffer& remoteBuffer) {
+            return static_cast<int>(self.read(localBuffer, remoteBuffer).get());
+          });
+
+  py::class_<RdmaMemory::View, std::shared_ptr<RdmaMemory::View>>(
+      m, "RdmaMemoryView")
+      .def("size", &RdmaMemory::View::size);
+
+  py::class_<RdmaMemory::MutableView, std::shared_ptr<RdmaMemory::MutableView>>(
+      m, "RdmaMemoryMutableView");
+
+  py::class_<RdmaMemory, std::shared_ptr<RdmaMemory>>(m, "RdmaMemory")
+      .def(py::init([](const at::Tensor& tensor) {
+        TORCH_CHECK(
+            tensor.is_contiguous(),
+            "RdmaMemory currently requires a contiguous tensor");
+        // If CPU memory is passed, use device 0 for NIC discovery
+        const auto device = tensor.get_device() < 0 ? 0 : tensor.get_device();
+        return std::make_shared<RdmaMemory>(
+            tensor.data_ptr(), tensor.nbytes(), device);
+      }))
+      .def(
+          "to_view",
+          [](RdmaMemory& self) {
+            return self.createView(size_t(0), self.length());
+          })
+      .def(
+          "to_mutable_view",
+          [](RdmaMemory& self) {
+            return self.createMutableView(size_t(0), self.length());
+          })
+      .def("to_remote_buffer", [](RdmaMemory& self) {
+        return RdmaRemoteBuffer{
+            const_cast<void*>(self.data()), self.remoteKey()};
+      });
+}
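
As a rough illustration of how the relocated bindings are meant to be driven from Python, the sketch below does a one-sided write between two peers. It is not part of this commit: the out-of-band URL and remote-buffer exchange (the `store` object), the tensor shapes, and the assumption that a return code of 0 means success are all illustrative; only the module and method names come from the bindings above.

# Hypothetical usage sketch. `store` is a placeholder for whatever rendezvous
# mechanism the peers use to exchange bytes (e.g. a torch.distributed store).
import pickle
import torch
from torchcomms._transport import RdmaMemory, RdmaTransport

assert RdmaTransport.supported()

device = torch.device("cuda", 0)
transport = RdmaTransport(device)

# Register a contiguous tensor with the NIC and derive a read-only view of it.
tensor = torch.ones(1024, device=device)
memory = RdmaMemory(tensor)
local_view = memory.to_view()

# Exchange connection URLs out of band, then connect (0 assumed to mean success).
store.set("peer0_url", transport.bind())
peer_url = store.get("peer1_url")
assert transport.connect(peer_url) == 0

# Exchange the pickled RdmaRemoteBuffer the same way, then issue a one-sided write.
store.set("peer0_buf", pickle.dumps(memory.to_remote_buffer()))
remote_buffer = pickle.loads(store.get("peer1_buf"))
assert transport.write(local_view, remote_buffer) == 0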
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# pyre-strict
+
+import torch
+
+class RdmaMemoryView:
+    def size(self) -> int: ...
+
+class RdmaMemoryMutableView: ...
+
+class RdmaRemoteBuffer:
+    def __getstate__(self) -> tuple[int, str]: ...
+    def __setstate__(self, state: tuple[int, str]) -> None: ...
+
+class RdmaMemory:
+    def __init__(self, tensor: torch.Tensor) -> None: ...
+    def to_view(self) -> RdmaMemoryView: ...
+    def to_mutable_view(self) -> RdmaMemoryMutableView: ...
+    def to_remote_buffer(self) -> RdmaRemoteBuffer: ...
+
+class RdmaTransport:
+    def __init__(self, device: torch.device) -> None: ...
+    @staticmethod
+    def supported() -> bool: ...
+    def bind(self) -> bytes: ...
+    def connect(self, peer_url: bytes) -> int: ...
+    def connected(self) -> bool: ...
+    def write(
+        self, local_buffer: RdmaMemoryView, remote_buffer: RdmaRemoteBuffer
+    ) -> int: ...
+    def read(
+        self, local_buffer: RdmaMemoryMutableView, remote_buffer: RdmaRemoteBuffer
+    ) -> int: ...
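
The read direction pairs a local RdmaMemoryMutableView with a peer's RdmaRemoteBuffer, per the stubs above. A hypothetical counterpart to the write sketch, reusing an already-connected `transport` and a previously exchanged `remote_buffer` (both placeholders from that sketch):

# Hypothetical read sketch: pull the peer's registered region into a local
# CPU tensor. A CPU tensor falls back to device 0 for NIC discovery per the
# bindings; a return code of 0 is assumed to mean success.
import torch
from torchcomms._transport import RdmaMemory

dst = torch.zeros(1024)
dst_memory = RdmaMemory(dst)
dst_view = dst_memory.to_mutable_view()

assert transport.read(dst_view, remote_buffer) == 0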

comms/torchcomms/tests/integration/py/TransportTest.py renamed to comms/torchcomms/transport/tests/py/TransportTest.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 import unittest
 
 import torch
-from torchcomms._comms_ncclx import RdmaMemory, RdmaTransport
+from torchcomms._transport import RdmaMemory, RdmaTransport
 
 
 class TransportTest(unittest.TestCase):
