Skip to content

Commit 200cdf9

Browse files
authored
Update EndpointConfig interfaces (#651)
* Separate IB-specific options into a nested struct
* Enable `connect()` by an `Endpoint`, not only by `EndpointConfig`
* Other minor changes
1 parent 610db6f commit 200cdf9

File tree

8 files changed

+155
-112
lines changed

8 files changed

+155
-112
lines changed

include/mscclpp/core.hpp

Lines changed: 59 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -374,41 +374,53 @@ struct Device {
374374
int id;
375375
};
376376

377-
/// Used to configure an endpoint.
377+
/// Configuration for creating communication endpoints.
378378
struct EndpointConfig {
379-
static const int DefaultMaxCqSize = 1024;
380-
static const int DefaultMaxCqPollNum = 1;
381-
static const int DefaultMaxSendWr = 8192;
382-
static const int DefaultMaxWrPerSend = 64;
383-
379+
/// InfiniBand-specific configuration options that control queue pair behavior and performance characteristics.
380+
/// These settings are only used when the transport is an InfiniBand type (IB0-IB7); they are ignored for other
381+
/// transports.
382+
struct Ib {
383+
static const int DefaultMaxCqSize = 1024;
384+
static const int DefaultMaxCqPollNum = 1;
385+
static const int DefaultMaxSendWr = 8192;
386+
static const int DefaultMaxWrPerSend = 64;
387+
388+
/// Maximum size of the completion queue.
389+
int maxCqSize;
390+
/// Maximum number of completion queue polls per operation.
391+
int maxCqPollNum;
392+
/// Maximum number of outstanding send work requests.
393+
int maxSendWr;
394+
/// Maximum number of work requests per send operation.
395+
int maxWrPerSend;
396+
397+
/// Constructor.
398+
/// @param maxCqSize Maximum completion queue size.
399+
/// @param maxCqPollNum Maximum completion queue poll count.
400+
/// @param maxSendWr Maximum outstanding send work requests.
401+
/// @param maxWrPerSend Maximum work requests per send operation.
402+
Ib(int maxCqSize = DefaultMaxCqSize, int maxCqPollNum = DefaultMaxCqPollNum, int maxSendWr = DefaultMaxSendWr,
403+
int maxWrPerSend = DefaultMaxWrPerSend)
404+
: maxCqSize(maxCqSize), maxCqPollNum(maxCqPollNum), maxSendWr(maxSendWr), maxWrPerSend(maxWrPerSend) {}
405+
};
406+
407+
/// Communication transport type (e.g., CudaIpc, IB0-IB7, Ethernet).
384408
Transport transport;
409+
/// Target device for the endpoint (GPU or CPU with optional device ID).
385410
Device device;
386-
int ibMaxCqSize;
387-
int ibMaxCqPollNum;
388-
int ibMaxSendWr;
389-
int ibMaxWrPerSend;
411+
/// Maximum number of write requests that can be queued (-1 for default).
390412
int maxWriteQueueSize;
391-
392-
/// Constructor that takes a transport and sets the other fields to their default values.
393-
///
394-
/// @param transport The transport to use.
395-
/// @param device The device to use.
396-
/// @param ibMaxCqSize The maximum completion queue size.
397-
/// @param ibMaxCqPollNum The maximum completion queue poll number.
398-
/// @param ibMaxSendWr The maximum send work requests.
399-
/// @param ibMaxWrPerSend The maximum work requests per send.
400-
/// @param maxWriteQueueSize The maximum write queue size.
401-
EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU,
402-
int ibMaxCqSize = DefaultMaxCqSize, int ibMaxCqPollNum = DefaultMaxCqPollNum,
403-
int ibMaxSendWr = DefaultMaxSendWr, int ibMaxWrPerSend = DefaultMaxWrPerSend,
404-
int maxWriteQueueSize = -1)
405-
: transport(transport),
406-
device(device),
407-
ibMaxCqSize(ibMaxCqSize),
408-
ibMaxCqPollNum(ibMaxCqPollNum),
409-
ibMaxSendWr(ibMaxSendWr),
410-
ibMaxWrPerSend(ibMaxWrPerSend),
411-
maxWriteQueueSize(maxWriteQueueSize) {}
413+
/// InfiniBand-specific options (used only for Transport::IBx).
414+
Ib ib;
415+
416+
/// Constructs endpoint configuration with specified transport, device, and optional settings.
417+
/// @param transport Communication transport to use.
418+
/// @param device Target device for the endpoint.
419+
/// @param maxWriteQueueSize Maximum write queue size (-1 for system default).
420+
/// @param ib IB-specific configuration.
421+
EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU, int maxWriteQueueSize = -1,
422+
Ib ib = {})
423+
: transport(transport), device(device), maxWriteQueueSize(maxWriteQueueSize), ib(ib) {}
412424
};
413425

414426
class Context;
@@ -423,6 +435,10 @@ class Endpoint {
423435
/// Constructor.
424436
Endpoint() = default;
425437

438+
/// Get the configuration used to create the endpoint.
439+
/// @return The configuration used to create the endpoint.
440+
const EndpointConfig& config() const;
441+
426442
/// Get the transport used.
427443
/// @return The transport used.
428444
Transport transport() const;
@@ -685,9 +701,9 @@ class Semaphore {
685701
std::shared_ptr<Impl> pimpl_;
686702
};
687703

704+
/// Deprecated. Use std::shared_future instead.
688705
template <typename T>
689-
using NonblockingFuture [[deprecated("Use std::shared_future instead. This will be removed in a future release.")]] =
690-
std::shared_future<T>;
706+
using NonblockingFuture = std::shared_future<T>;
691707

692708
/// A class that sets up all registered memories and connections between processes.
693709
///
@@ -853,12 +869,20 @@ class Communicator {
853869
/// on the last future, it will start receiving the five RegisteredMemory or Connection objects in order,
854870
/// back to back.
855871
///
856-
/// @param localConfig The configuration for the local endpoint.
872+
/// @param localEndpoint The local endpoint.
857873
/// @param remoteRank The rank of the remote process.
858874
/// @param tag The tag to use for identifying the send and receive.
859875
/// @return A future of shared pointer to the connection.
860876
///
861-
std::shared_future<std::shared_ptr<Connection>> connect(EndpointConfig localConfig, int remoteRank, int tag = 0);
877+
std::shared_future<std::shared_ptr<Connection>> connect(const Endpoint& localEndpoint, int remoteRank, int tag = 0);
878+
879+
/// Connect to a remote rank. Creates a local endpoint from `localConfig` and forwards to `connect(localEndpoint, remoteRank, tag)`.
880+
/// @param localConfig The configuration for the local endpoint.
881+
/// @param remoteRank The rank of the remote process.
882+
/// @param tag The tag to use for identifying the send and receive.
883+
/// @return A future of shared pointer to the connection.
884+
std::shared_future<std::shared_ptr<Connection>> connect(const EndpointConfig& localConfig, int remoteRank,
885+
int tag = 0);
862886

863887
[[deprecated("Use connect(localConfig, remoteRank, tag) instead. This will be removed in a future release.")]] std::
864888
shared_future<std::shared_ptr<Connection>>

python/mscclpp/core_py.cpp

Lines changed: 32 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -124,6 +124,17 @@ void register_core(nb::module_& m) {
124124
.def_rw("id", &Device::id)
125125
.def("__str__", [](const Device& self) { return std::to_string(self); });
126126

127+
nb::class_<EndpointConfig::Ib>(m, "EndpointConfigIb")
128+
.def(nb::init<>())
129+
.def(nb::init<int, int, int, int>(), nb::arg("maxCqSize") = EndpointConfig::Ib::DefaultMaxCqSize,
130+
nb::arg("maxCqPollNum") = EndpointConfig::Ib::DefaultMaxCqPollNum,
131+
nb::arg("maxSendWr") = EndpointConfig::Ib::DefaultMaxSendWr,
132+
nb::arg("maxWrPerSend") = EndpointConfig::Ib::DefaultMaxWrPerSend)
133+
.def_rw("max_cq_size", &EndpointConfig::Ib::maxCqSize)
134+
.def_rw("max_cq_poll_num", &EndpointConfig::Ib::maxCqPollNum)
135+
.def_rw("max_send_wr", &EndpointConfig::Ib::maxSendWr)
136+
.def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend);
137+
127138
nb::class_<RegisteredMemory>(m, "RegisteredMemory")
128139
.def(nb::init<>())
129140
.def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
@@ -158,17 +169,23 @@ void register_core(nb::module_& m) {
158169
nb::class_<EndpointConfig>(m, "EndpointConfig")
159170
.def(nb::init<>())
160171
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
161-
.def(nb::init<Transport, Device, int, int, int, int, int>(), nb::arg("transport"), nb::arg("device"),
162-
nb::arg("ibMaxCqSize") = EndpointConfig::DefaultMaxCqSize,
163-
nb::arg("ibMaxCqPollNum") = EndpointConfig::DefaultMaxCqPollNum,
164-
nb::arg("ibMaxSendWr") = EndpointConfig::DefaultMaxSendWr,
165-
nb::arg("ibMaxWrPerSend") = EndpointConfig::DefaultMaxWrPerSend, nb::arg("maxWriteQueueSize") = -1)
172+
.def(nb::init<Transport, Device, int, EndpointConfig::Ib>(), nb::arg("transport"), nb::arg("device"),
173+
nb::arg("maxWriteQueueSize") = -1, nb::arg("ib") = EndpointConfig::Ib{})
166174
.def_rw("transport", &EndpointConfig::transport)
167175
.def_rw("device", &EndpointConfig::device)
168-
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
169-
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
170-
.def_rw("ib_max_send_wr", &EndpointConfig::ibMaxSendWr)
171-
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend)
176+
.def_rw("ib", &EndpointConfig::ib)
177+
.def_prop_rw(
178+
"ib_max_cq_size", [](EndpointConfig& self) { return self.ib.maxCqSize; },
179+
[](EndpointConfig& self, int v) { self.ib.maxCqSize = v; })
180+
.def_prop_rw(
181+
"ib_max_cq_poll_num", [](EndpointConfig& self) { return self.ib.maxCqPollNum; },
182+
[](EndpointConfig& self, int v) { self.ib.maxCqPollNum = v; })
183+
.def_prop_rw(
184+
"ib_max_send_wr", [](EndpointConfig& self) { return self.ib.maxSendWr; },
185+
[](EndpointConfig& self, int v) { self.ib.maxSendWr = v; })
186+
.def_prop_rw(
187+
"ib_max_wr_per_send", [](EndpointConfig& self) { return self.ib.maxWrPerSend; },
188+
[](EndpointConfig& self, int v) { self.ib.maxWrPerSend = v; })
172189
.def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);
173190

174191
nb::class_<Context>(m, "Context")
@@ -212,13 +229,15 @@ void register_core(nb::module_& m) {
212229
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag") = 0)
213230
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag") = 0)
214231
.def("connect",
215-
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(EndpointConfig, int, int)>(
232+
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(const Endpoint&, int, int)>(
216233
&Communicator::connect),
217-
nb::arg("localConfig"), nb::arg("remoteRank"), nb::arg("tag") = 0)
234+
nb::arg("localEndpoint"), nb::arg("remoteRank"), nb::arg("tag") = 0)
235+
.def("connect", [](Communicator* self, const EndpointConfig& localConfig, int remoteRank,
236+
int tag = 0) { return self->connect(localConfig, remoteRank, tag); })
218237
.def(
219238
"connect",
220-
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
221-
return self->connect(std::move(localConfig), remoteRank, tag);
239+
[](Communicator* self, int remoteRank, int tag, const EndpointConfig& localConfig) {
240+
return self->connect(localConfig, remoteRank, tag);
222241
},
223242
nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
224243
.def(

src/communicator.cc

Lines changed: 22 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -99,41 +99,44 @@ MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory(in
9999
return shared_future;
100100
}
101101

102-
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(EndpointConfig localConfig,
102+
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(const Endpoint& localEndpoint,
103103
int remoteRank, int tag) {
104-
auto localEndpoint = context()->createEndpoint(localConfig);
105-
106104
if (remoteRank == bootstrap()->getRank()) {
107105
// Connection to self
108-
auto remoteEndpoint = context()->createEndpoint(localConfig);
106+
auto remoteEndpoint = context()->createEndpoint(localEndpoint.config());
109107
auto connection = context()->connect(localEndpoint, remoteEndpoint);
110108
std::promise<std::shared_ptr<Connection>> promise;
111109
promise.set_value(connection);
112110
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
113-
return std::shared_future<std::shared_ptr<Connection>>(std::move(promise.get_future()));
111+
return std::shared_future<std::shared_ptr<Connection>>(promise.get_future());
114112
}
115113

116114
bootstrap()->send(localEndpoint.serialize(), remoteRank, tag);
117115

118-
auto future =
119-
std::async(std::launch::deferred, [this, remoteRank, tag, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag),
120-
localEndpoint = std::move(localEndpoint)]() mutable {
121-
if (lastRecvItem) {
122-
// Recursive call to the previous receive items
123-
lastRecvItem->wait();
124-
}
125-
std::vector<char> data;
126-
bootstrap()->recv(data, remoteRank, tag);
127-
auto remoteEndpoint = Endpoint::deserialize(data);
128-
auto connection = context()->connect(localEndpoint, remoteEndpoint);
129-
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
130-
return connection;
131-
});
116+
auto future = std::async(std::launch::deferred, [this, remoteRank, tag, localEndpoint,
117+
lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag)]() mutable {
118+
if (lastRecvItem) {
119+
// Recursive call to the previous receive items
120+
lastRecvItem->wait();
121+
}
122+
std::vector<char> data;
123+
bootstrap()->recv(data, remoteRank, tag);
124+
auto remoteEndpoint = Endpoint::deserialize(data);
125+
auto connection = context()->connect(localEndpoint, remoteEndpoint);
126+
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
127+
return connection;
128+
});
132129
auto shared_future = std::shared_future<std::shared_ptr<Connection>>(std::move(future));
133130
pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<std::shared_ptr<Connection>>>(shared_future));
134131
return shared_future;
135132
}
136133

134+
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(const EndpointConfig& localConfig,
135+
int remoteRank, int tag) {
136+
auto localEndpoint = context()->createEndpoint(localConfig);
137+
return connect(localEndpoint, remoteRank, tag);
138+
}
139+
137140
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
138141
EndpointConfig localConfig) {
139142
return connect(localConfig, remoteRank, tag);

src/connection.cc

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -167,9 +167,6 @@ IBConnection::IBConnection(std::shared_ptr<Context> context, const Endpoint& loc
167167
transport_(localEndpoint.transport()),
168168
remoteTransport_(remoteEndpoint.transport()),
169169
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
170-
if (maxWriteQueueSize_ == -1) {
171-
maxWriteQueueSize_ = EndpointConfig::DefaultMaxCqSize;
172-
}
173170
qp_ = getImpl(localEndpoint).ibQp_;
174171
qp_.lock()->rtr(getImpl(remoteEndpoint).ibQpInfo_);
175172
qp_.lock()->rts();

src/context.cc

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
#include "context.hpp"
55

66
#include <mscclpp/env.hpp>
7+
#include <sstream>
78

89
#include "api.h"
910
#include "connection.hpp"
@@ -76,21 +77,21 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(const Endpoint &loc
7677
if (remoteEndpoint.device().type == DeviceType::GPU && remoteEndpoint.device().id < 0) {
7778
throw Error("No GPU device ID provided for remote endpoint", ErrorCode::InvalidUsage);
7879
}
80+
auto localTransport = localEndpoint.transport();
81+
auto remoteTransport = remoteEndpoint.transport();
82+
if (localTransport != remoteTransport &&
83+
!(AllIBTransports.has(localTransport) && AllIBTransports.has(remoteTransport))) {
84+
std::stringstream ss;
85+
ss << "Transport mismatch between local (" << std::to_string(localTransport) << ") and remote ("
86+
<< std::to_string(remoteEndpoint.transport()) << ") endpoints";
87+
throw Error(ss.str(), ErrorCode::InvalidUsage);
88+
}
7989
std::shared_ptr<Connection> conn;
80-
if (localEndpoint.transport() == Transport::CudaIpc) {
81-
if (remoteEndpoint.transport() != Transport::CudaIpc) {
82-
throw Error("Local transport is CudaIpc but remote is not", ErrorCode::InvalidUsage);
83-
}
90+
if (localTransport == Transport::CudaIpc) {
8491
conn = std::make_shared<CudaIpcConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
85-
} else if (AllIBTransports.has(localEndpoint.transport())) {
86-
if (!AllIBTransports.has(remoteEndpoint.transport())) {
87-
throw Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
88-
}
92+
} else if (AllIBTransports.has(localTransport)) {
8993
conn = std::make_shared<IBConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
90-
} else if (localEndpoint.transport() == Transport::Ethernet) {
91-
if (remoteEndpoint.transport() != Transport::Ethernet) {
92-
throw Error("Local transport is Ethernet but remote is not", ErrorCode::InvalidUsage);
93-
}
94+
} else if (localTransport == Transport::Ethernet) {
9495
conn = std::make_shared<EthernetConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
9596
} else {
9697
throw Error("Unsupported transport", ErrorCode::InternalError);

0 commit comments

Comments
 (0)