@@ -70,9 +70,9 @@ void register_core(nb::module_& m) {
7070 .def_static (" create_unique_id" , &TcpBootstrap::createUniqueId)
7171 .def (" get_unique_id" , &TcpBootstrap::getUniqueId)
7272 .def (" initialize" , static_cast <void (TcpBootstrap::*)(UniqueId, int64_t )>(&TcpBootstrap::initialize),
73- nb::call_guard<nb::gil_scoped_release>(), nb::arg (" uniqueId " ), nb::arg (" timeoutSec " ) = 30 )
73+ nb::call_guard<nb::gil_scoped_release>(), nb::arg (" unique_id " ), nb::arg (" timeout_sec " ) = 30 )
7474 .def (" initialize" , static_cast <void (TcpBootstrap::*)(const std::string&, int64_t )>(&TcpBootstrap::initialize),
75- nb::call_guard<nb::gil_scoped_release>(), nb::arg (" ifIpPortTrio " ), nb::arg (" timeoutSec " ) = 30 );
75+ nb::call_guard<nb::gil_scoped_release>(), nb::arg (" if_ip_port_trio " ), nb::arg (" timeout_sec " ) = 30 );
7676
7777 nb::enum_<Transport>(m, " Transport" )
7878 .value (" Unknown" , Transport::Unknown)
@@ -126,10 +126,10 @@ void register_core(nb::module_& m) {
126126
127127 nb::class_<EndpointConfig::Ib>(m, " EndpointConfigIb" )
128128 .def (nb::init<>())
129- .def (nb::init<int , int , int , int >(), nb::arg (" maxCqSize " ) = EndpointConfig::Ib::DefaultMaxCqSize,
130- nb::arg (" maxCqPollNum " ) = EndpointConfig::Ib::DefaultMaxCqPollNum,
131- nb::arg (" maxSendWr " ) = EndpointConfig::Ib::DefaultMaxSendWr,
132- nb::arg (" maxWrPerSend " ) = EndpointConfig::Ib::DefaultMaxWrPerSend)
129+ .def (nb::init<int , int , int , int >(), nb::arg (" max_cq_size " ) = EndpointConfig::Ib::DefaultMaxCqSize,
130+ nb::arg (" max_cq_poll_num " ) = EndpointConfig::Ib::DefaultMaxCqPollNum,
131+ nb::arg (" max_send_wr " ) = EndpointConfig::Ib::DefaultMaxSendWr,
132+ nb::arg (" max_wr_per_send " ) = EndpointConfig::Ib::DefaultMaxWrPerSend)
133133 .def_rw (" max_cq_size" , &EndpointConfig::Ib::maxCqSize)
134134 .def_rw (" max_cq_poll_num" , &EndpointConfig::Ib::maxCqPollNum)
135135 .def_rw (" max_send_wr" , &EndpointConfig::Ib::maxSendWr)
@@ -144,6 +144,7 @@ void register_core(nb::module_& m) {
144144 .def_static (" deserialize" , &RegisteredMemory::deserialize, nb::arg (" data" ));
145145
146146 nb::class_<Endpoint>(m, " Endpoint" )
147+ .def (" config" , &Endpoint::config)
147148 .def (" transport" , &Endpoint::transport)
148149 .def (" device" , &Endpoint::device)
149150 .def (" max_write_queue_size" , &Endpoint::maxWriteQueueSize)
@@ -158,8 +159,9 @@ void register_core(nb::module_& m) {
158159 [](Connection* self, RegisteredMemory dst, uint64_t dstOffset, uintptr_t src, uint64_t newValue) {
159160 self->updateAndSync (dst, dstOffset, (uint64_t *)src, newValue);
160161 },
161- nb::arg (" dst" ), nb::arg (" dstOffset" ), nb::arg (" src" ), nb::arg (" newValue" ))
162- .def (" flush" , &Connection::flush, nb::call_guard<nb::gil_scoped_release>(), nb::arg (" timeoutUsec" ) = (int64_t )3e7 )
162+ nb::arg (" dst" ), nb::arg (" dst_offset" ), nb::arg (" src" ), nb::arg (" new_value" ))
163+ .def (" flush" , &Connection::flush, nb::call_guard<nb::gil_scoped_release>(),
164+ nb::arg (" timeout_usec" ) = (int64_t )3e7 )
163165 .def (" transport" , &Connection::transport)
164166 .def (" remote_transport" , &Connection::remoteTransport)
165167 .def (" context" , &Connection::context)
@@ -170,7 +172,7 @@ void register_core(nb::module_& m) {
170172 .def (nb::init<>())
171173 .def (nb::init_implicit<Transport>(), nb::arg (" transport" ))
172174 .def (nb::init<Transport, Device, int , EndpointConfig::Ib>(), nb::arg (" transport" ), nb::arg (" device" ),
173- nb::arg (" maxWriteQueueSize " ) = -1 , nb::arg (" ib" ) = EndpointConfig::Ib{})
175+ nb::arg (" max_write_queue_size " ) = -1 , nb::arg (" ib" ) = EndpointConfig::Ib{})
174176 .def_rw (" transport" , &EndpointConfig::transport)
175177 .def_rw (" device" , &EndpointConfig::device)
176178 .def_rw (" ib" , &EndpointConfig::ib)
@@ -192,7 +194,7 @@ void register_core(nb::module_& m) {
192194 .def_static (" create" , &Context::create)
193195 .def (
194196 " register_memory" ,
195- [](Communicator * self, uintptr_t ptr, size_t size, TransportFlags transports) {
197+ [](Context * self, uintptr_t ptr, size_t size, TransportFlags transports) {
196198 return self->registerMemory ((void *)ptr, size, transports);
197199 },
198200 nb::arg (" ptr" ), nb::arg (" size" ), nb::arg (" transports" ))
@@ -207,7 +209,7 @@ void register_core(nb::module_& m) {
207209
208210 nb::class_<Semaphore>(m, " Semaphore" )
209211 .def (nb::init<>())
210- .def (nb::init<const SemaphoreStub&, const SemaphoreStub&>(), nb::arg (" localStub " ), nb::arg (" remoteStub " ))
212+ .def (nb::init<const SemaphoreStub&, const SemaphoreStub&>(), nb::arg (" local_stub " ), nb::arg (" remote_stub " ))
211213 .def (" connection" , &Semaphore::connection)
212214 .def (" local_memory" , &Semaphore::localMemory)
213215 .def (" remote_memory" , &Semaphore::remoteMemory);
@@ -226,29 +228,21 @@ void register_core(nb::module_& m) {
226228 return self->registerMemory ((void *)ptr, size, transports);
227229 },
228230 nb::arg (" ptr" ), nb::arg (" size" ), nb::arg (" transports" ))
229- .def (" send_memory" , &Communicator::sendMemory, nb::arg (" memory" ), nb::arg (" remoteRank " ), nb::arg (" tag" ) = 0 )
230- .def (" recv_memory" , &Communicator::recvMemory, nb::arg (" remoteRank " ), nb::arg (" tag" ) = 0 )
231+ .def (" send_memory" , &Communicator::sendMemory, nb::arg (" memory" ), nb::arg (" remote_rank " ), nb::arg (" tag" ) = 0 )
232+ .def (" recv_memory" , &Communicator::recvMemory, nb::arg (" remote_rank " ), nb::arg (" tag" ) = 0 )
231233 .def (" connect" ,
232- static_cast <std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(const Endpoint&, int , int )>(
233- &Communicator::connect),
234- nb::arg (" localEndpoint" ), nb::arg (" remoteRank" ), nb::arg (" tag" ) = 0 )
235- .def (" connect" , [](Communicator* self, const EndpointConfig& localConfig, int remoteRank,
236- int tag = 0 ) { return self->connect (localConfig, remoteRank, tag); })
237- .def (
238- " connect" ,
239- [](Communicator* self, int remoteRank, int tag, const EndpointConfig& localConfig) {
240- return self->connect (localConfig, remoteRank, tag);
241- },
242- nb::arg (" remoteRank" ), nb::arg (" tag" ), nb::arg (" localConfig" ))
234+ static_cast <std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(const EndpointConfig&, int ,
235+ int )>(&Communicator::connect),
236+ nb::arg (" local_config" ), nb::arg (" remote_rank" ), nb::arg (" tag" ) = 0 )
243237 .def (
244238 " connect_on_setup" ,
245239 [](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
246240 return self->connect (std::move (localConfig), remoteRank, tag);
247241 },
248- nb::arg (" remoteRank " ), nb::arg (" tag" ), nb::arg (" localConfig " ))
249- .def (" send_memory_on_setup" , &Communicator::sendMemory, nb::arg (" memory" ), nb::arg (" remoteRank " ), nb::arg (" tag" ))
250- .def (" recv_memory_on_setup" , &Communicator::recvMemory, nb::arg (" remoteRank " ), nb::arg (" tag" ))
251- .def (" build_semaphore" , &Communicator::buildSemaphore, nb::arg (" localFlag " ), nb::arg (" remoteRank " ),
242+ nb::arg (" remote_rank " ), nb::arg (" tag" ), nb::arg (" local_config " ))
243+ .def (" send_memory_on_setup" , &Communicator::sendMemory, nb::arg (" memory" ), nb::arg (" remote_rank " ), nb::arg (" tag" ))
244+ .def (" recv_memory_on_setup" , &Communicator::recvMemory, nb::arg (" remote_rank " ), nb::arg (" tag" ))
245+ .def (" build_semaphore" , &Communicator::buildSemaphore, nb::arg (" local_flag " ), nb::arg (" remote_rank " ),
252246 nb::arg (" tag" ) = 0 )
253247 .def (" remote_rank_of" , &Communicator::remoteRankOf)
254248 .def (" tag_of" , &Communicator::tagOf)
0 commit comments