Skip to content

Commit 09219c1

Browse files
authored
Fix #651 (#662)
* Python cannot distinguish `Communicator::connect(const Endpoint&, ...)` from `Communicator::connect(const EndpointConfig&, ...)`, so the former overload has been temporarily removed.
* A few other fixes in the Python bindings.
1 parent 68b1f15 commit 09219c1

File tree

5 files changed

+36
-41
lines changed

5 files changed

+36
-41
lines changed

python/mscclpp/core_py.cpp

Lines changed: 22 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -70,9 +70,9 @@ void register_core(nb::module_& m) {
7070
.def_static("create_unique_id", &TcpBootstrap::createUniqueId)
7171
.def("get_unique_id", &TcpBootstrap::getUniqueId)
7272
.def("initialize", static_cast<void (TcpBootstrap::*)(UniqueId, int64_t)>(&TcpBootstrap::initialize),
73-
nb::call_guard<nb::gil_scoped_release>(), nb::arg("uniqueId"), nb::arg("timeoutSec") = 30)
73+
nb::call_guard<nb::gil_scoped_release>(), nb::arg("unique_id"), nb::arg("timeout_sec") = 30)
7474
.def("initialize", static_cast<void (TcpBootstrap::*)(const std::string&, int64_t)>(&TcpBootstrap::initialize),
75-
nb::call_guard<nb::gil_scoped_release>(), nb::arg("ifIpPortTrio"), nb::arg("timeoutSec") = 30);
75+
nb::call_guard<nb::gil_scoped_release>(), nb::arg("if_ip_port_trio"), nb::arg("timeout_sec") = 30);
7676

7777
nb::enum_<Transport>(m, "Transport")
7878
.value("Unknown", Transport::Unknown)
@@ -126,10 +126,10 @@ void register_core(nb::module_& m) {
126126

127127
nb::class_<EndpointConfig::Ib>(m, "EndpointConfigIb")
128128
.def(nb::init<>())
129-
.def(nb::init<int, int, int, int>(), nb::arg("maxCqSize") = EndpointConfig::Ib::DefaultMaxCqSize,
130-
nb::arg("maxCqPollNum") = EndpointConfig::Ib::DefaultMaxCqPollNum,
131-
nb::arg("maxSendWr") = EndpointConfig::Ib::DefaultMaxSendWr,
132-
nb::arg("maxWrPerSend") = EndpointConfig::Ib::DefaultMaxWrPerSend)
129+
.def(nb::init<int, int, int, int>(), nb::arg("max_cq_size") = EndpointConfig::Ib::DefaultMaxCqSize,
130+
nb::arg("max_cq_poll_num") = EndpointConfig::Ib::DefaultMaxCqPollNum,
131+
nb::arg("max_send_wr") = EndpointConfig::Ib::DefaultMaxSendWr,
132+
nb::arg("max_wr_per_send") = EndpointConfig::Ib::DefaultMaxWrPerSend)
133133
.def_rw("max_cq_size", &EndpointConfig::Ib::maxCqSize)
134134
.def_rw("max_cq_poll_num", &EndpointConfig::Ib::maxCqPollNum)
135135
.def_rw("max_send_wr", &EndpointConfig::Ib::maxSendWr)
@@ -144,6 +144,7 @@ void register_core(nb::module_& m) {
144144
.def_static("deserialize", &RegisteredMemory::deserialize, nb::arg("data"));
145145

146146
nb::class_<Endpoint>(m, "Endpoint")
147+
.def("config", &Endpoint::config)
147148
.def("transport", &Endpoint::transport)
148149
.def("device", &Endpoint::device)
149150
.def("max_write_queue_size", &Endpoint::maxWriteQueueSize)
@@ -158,8 +159,9 @@ void register_core(nb::module_& m) {
158159
[](Connection* self, RegisteredMemory dst, uint64_t dstOffset, uintptr_t src, uint64_t newValue) {
159160
self->updateAndSync(dst, dstOffset, (uint64_t*)src, newValue);
160161
},
161-
nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
162-
.def("flush", &Connection::flush, nb::call_guard<nb::gil_scoped_release>(), nb::arg("timeoutUsec") = (int64_t)3e7)
162+
nb::arg("dst"), nb::arg("dst_offset"), nb::arg("src"), nb::arg("new_value"))
163+
.def("flush", &Connection::flush, nb::call_guard<nb::gil_scoped_release>(),
164+
nb::arg("timeout_usec") = (int64_t)3e7)
163165
.def("transport", &Connection::transport)
164166
.def("remote_transport", &Connection::remoteTransport)
165167
.def("context", &Connection::context)
@@ -170,7 +172,7 @@ void register_core(nb::module_& m) {
170172
.def(nb::init<>())
171173
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
172174
.def(nb::init<Transport, Device, int, EndpointConfig::Ib>(), nb::arg("transport"), nb::arg("device"),
173-
nb::arg("maxWriteQueueSize") = -1, nb::arg("ib") = EndpointConfig::Ib{})
175+
nb::arg("max_write_queue_size") = -1, nb::arg("ib") = EndpointConfig::Ib{})
174176
.def_rw("transport", &EndpointConfig::transport)
175177
.def_rw("device", &EndpointConfig::device)
176178
.def_rw("ib", &EndpointConfig::ib)
@@ -192,7 +194,7 @@ void register_core(nb::module_& m) {
192194
.def_static("create", &Context::create)
193195
.def(
194196
"register_memory",
195-
[](Communicator* self, uintptr_t ptr, size_t size, TransportFlags transports) {
197+
[](Context* self, uintptr_t ptr, size_t size, TransportFlags transports) {
196198
return self->registerMemory((void*)ptr, size, transports);
197199
},
198200
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
@@ -207,7 +209,7 @@ void register_core(nb::module_& m) {
207209

208210
nb::class_<Semaphore>(m, "Semaphore")
209211
.def(nb::init<>())
210-
.def(nb::init<const SemaphoreStub&, const SemaphoreStub&>(), nb::arg("localStub"), nb::arg("remoteStub"))
212+
.def(nb::init<const SemaphoreStub&, const SemaphoreStub&>(), nb::arg("local_stub"), nb::arg("remote_stub"))
211213
.def("connection", &Semaphore::connection)
212214
.def("local_memory", &Semaphore::localMemory)
213215
.def("remote_memory", &Semaphore::remoteMemory);
@@ -226,29 +228,21 @@ void register_core(nb::module_& m) {
226228
return self->registerMemory((void*)ptr, size, transports);
227229
},
228230
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
229-
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag") = 0)
230-
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag") = 0)
231+
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remote_rank"), nb::arg("tag") = 0)
232+
.def("recv_memory", &Communicator::recvMemory, nb::arg("remote_rank"), nb::arg("tag") = 0)
231233
.def("connect",
232-
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(const Endpoint&, int, int)>(
233-
&Communicator::connect),
234-
nb::arg("localEndpoint"), nb::arg("remoteRank"), nb::arg("tag") = 0)
235-
.def("connect", [](Communicator* self, const EndpointConfig& localConfig, int remoteRank,
236-
int tag = 0) { return self->connect(localConfig, remoteRank, tag); })
237-
.def(
238-
"connect",
239-
[](Communicator* self, int remoteRank, int tag, const EndpointConfig& localConfig) {
240-
return self->connect(localConfig, remoteRank, tag);
241-
},
242-
nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
234+
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(const EndpointConfig&, int,
235+
int)>(&Communicator::connect),
236+
nb::arg("local_config"), nb::arg("remote_rank"), nb::arg("tag") = 0)
243237
.def(
244238
"connect_on_setup",
245239
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
246240
return self->connect(std::move(localConfig), remoteRank, tag);
247241
},
248-
nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
249-
.def("send_memory_on_setup", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag"))
250-
.def("recv_memory_on_setup", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag"))
251-
.def("build_semaphore", &Communicator::buildSemaphore, nb::arg("localFlag"), nb::arg("remoteRank"),
242+
nb::arg("remote_rank"), nb::arg("tag"), nb::arg("local_config"))
243+
.def("send_memory_on_setup", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remote_rank"), nb::arg("tag"))
244+
.def("recv_memory_on_setup", &Communicator::recvMemory, nb::arg("remote_rank"), nb::arg("tag"))
245+
.def("build_semaphore", &Communicator::buildSemaphore, nb::arg("local_flag"), nb::arg("remote_rank"),
252246
nb::arg("tag") = 0)
253247
.def("remote_rank_of", &Communicator::remoteRankOf)
254248
.def("tag_of", &Communicator::tagOf)

python/mscclpp/executor_py.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ void register_executor(nb::module_& m) {
3737
self->execute(rank, reinterpret_cast<void*>(sendbuff), reinterpret_cast<void*>(recvBuff), sendBuffSize,
3838
recvBuffSize, dataType, plan, (cudaStream_t)stream, packetType);
3939
},
40-
nb::arg("rank"), nb::arg("sendbuff"), nb::arg("recvBuff"), nb::arg("sendBuffSize"), nb::arg("recvBuffSize"),
41-
nb::arg("dataType"), nb::arg("plan"), nb::arg("stream"), nb::arg("packetType") = PacketType::LL16);
40+
nb::arg("rank"), nb::arg("send_buff"), nb::arg("recv_buff"), nb::arg("send_buff_size"),
41+
nb::arg("recv_buff_size"), nb::arg("data_type"), nb::arg("plan"), nb::arg("stream"),
42+
nb::arg("packet_type") = PacketType::LL16);
4243
}

python/mscclpp/gpu_utils_py.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -112,5 +112,5 @@ void register_gpu_utils(nb::module_& m) {
112112
[](GpuBuffer<char>& self, std::string dataType, std::vector<int64_t> shape, std::vector<int64_t> strides) {
113113
return toDlpack(self, dataType, shape, strides);
114114
},
115-
nb::arg("dataType"), nb::arg("shape") = std::vector<int64_t>(), nb::arg("strides") = std::vector<int64_t>());
115+
nb::arg("data_type"), nb::arg("shape") = std::vector<int64_t>(), nb::arg("strides") = std::vector<int64_t>());
116116
}

python/mscclpp/port_channel_py.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ void register_port_channel(nb::module_& m) {
1616
.def("stop_proxy", &BaseProxyService::stopProxy);
1717

1818
nb::class_<ProxyService, BaseProxyService>(m, "ProxyService")
19-
.def(nb::init<int>(), nb::arg("fifoSize") = DEFAULT_FIFO_SIZE)
19+
.def(nb::init<int>(), nb::arg("fifo_size") = DEFAULT_FIFO_SIZE)
2020
.def("start_proxy", &ProxyService::startProxy)
2121
.def("stop_proxy", &ProxyService::stopProxy)
2222
.def("build_and_add_semaphore", &ProxyService::buildAndAddSemaphore, nb::arg("comm"), nb::arg("connection"))
@@ -34,12 +34,12 @@ void register_port_channel(nb::module_& m) {
3434
nb::class_<BasePortChannel>(m, "BasePortChannel")
3535
.def(nb::init<>())
3636
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>>(),
37-
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"))
37+
nb::arg("semaphore_id"), nb::arg("semaphore"), nb::arg("proxy"))
3838
.def("device_handle", &BasePortChannel::deviceHandle);
3939

4040
nb::class_<BasePortChannel::DeviceHandle>(m, "BasePortChannelDeviceHandle")
4141
.def(nb::init<>())
42-
.def_rw("semaphoreId_", &BasePortChannel::DeviceHandle::semaphoreId_)
42+
.def_rw("semaphore_id_", &BasePortChannel::DeviceHandle::semaphoreId_)
4343
.def_rw("semaphore_", &BasePortChannel::DeviceHandle::semaphore_)
4444
.def_rw("fifo_", &BasePortChannel::DeviceHandle::fifo_)
4545
.def_prop_ro("raw", [](const BasePortChannel::DeviceHandle& self) -> nb::bytes {
@@ -49,12 +49,12 @@ void register_port_channel(nb::module_& m) {
4949
nb::class_<PortChannel>(m, "PortChannel")
5050
.def(nb::init<>())
5151
.def(nb::init<SemaphoreId, std::shared_ptr<Host2DeviceSemaphore>, std::shared_ptr<Proxy>, MemoryId, MemoryId>(),
52-
nb::arg("semaphoreId"), nb::arg("semaphore"), nb::arg("proxy"), nb::arg("dst"), nb::arg("src"))
52+
nb::arg("semaphore_id"), nb::arg("semaphore"), nb::arg("proxy"), nb::arg("dst"), nb::arg("src"))
5353
.def("device_handle", &PortChannel::deviceHandle);
5454

5555
nb::class_<PortChannel::DeviceHandle>(m, "PortChannelDeviceHandle")
5656
.def(nb::init<>())
57-
.def_rw("semaphoreId_", &PortChannel::DeviceHandle::semaphoreId_)
57+
.def_rw("semaphore_id_", &PortChannel::DeviceHandle::semaphoreId_)
5858
.def_rw("semaphore_", &PortChannel::DeviceHandle::semaphore_)
5959
.def_rw("fifo_", &PortChannel::DeviceHandle::fifo_)
6060
.def_rw("src_", &PortChannel::DeviceHandle::src_)

python/mscclpp/switch_channel_py.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -21,17 +21,17 @@ void register_nvls(nb::module_& m) {
2121

2222
nb::class_<SwitchChannel::DeviceHandle>(m, "DeviceHandle")
2323
.def(nb::init<>())
24-
.def_rw("devicePtr", &SwitchChannel::DeviceHandle::devicePtr)
25-
.def_rw("mcPtr", &SwitchChannel::DeviceHandle::mcPtr)
24+
.def_rw("device_ptr", &SwitchChannel::DeviceHandle::devicePtr)
25+
.def_rw("mc_ptr", &SwitchChannel::DeviceHandle::mcPtr)
2626
.def_rw("size", &SwitchChannel::DeviceHandle::bufferSize)
2727
.def_prop_ro("raw", [](const SwitchChannel::DeviceHandle& self) -> nb::bytes {
2828
return nb::bytes(reinterpret_cast<const char*>(&self), sizeof(self));
2929
});
3030

3131
nb::class_<NvlsConnection>(m, "NvlsConnection")
32-
.def("bind_allocated_memory", &NvlsConnection::bindAllocatedMemory, nb::arg("devicePtr"), nb::arg("size"))
32+
.def("bind_allocated_memory", &NvlsConnection::bindAllocatedMemory, nb::arg("device_ptr"), nb::arg("size"))
3333
.def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity);
3434

35-
m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("allRanks"),
36-
nb::arg("bufferSize"));
35+
m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("all_ranks"),
36+
nb::arg("buffer_size"));
3737
}

0 commit comments

Comments
 (0)