Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 59 additions & 35 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,41 +374,53 @@ struct Device {
int id;
};

/// Used to configure an endpoint.
/// Configuration for creating communication endpoints.
struct EndpointConfig {
static const int DefaultMaxCqSize = 1024;
static const int DefaultMaxCqPollNum = 1;
static const int DefaultMaxSendWr = 8192;
static const int DefaultMaxWrPerSend = 64;

/// InfiniBand-specific configuration options that control queue pair behavior and performance characteristics.
/// These settings are only used when the transport is an InfiniBand type (IB0-IB7); they are ignored for other
/// transports.
struct Ib {
static const int DefaultMaxCqSize = 1024;
static const int DefaultMaxCqPollNum = 1;
static const int DefaultMaxSendWr = 8192;
static const int DefaultMaxWrPerSend = 64;

/// Maximum size of the completion queue.
int maxCqSize;
/// Maximum number of completion queue polls per operation.
int maxCqPollNum;
/// Maximum number of outstanding send work requests.
int maxSendWr;
/// Maximum number of work requests per send operation.
int maxWrPerSend;

/// Constructor.
/// @param maxCqSize Maximum completion queue size.
/// @param maxCqPollNum Maximum completion queue poll count.
/// @param maxSendWr Maximum outstanding send work requests.
/// @param maxWrPerSend Maximum work requests per send operation.
Ib(int maxCqSize = DefaultMaxCqSize, int maxCqPollNum = DefaultMaxCqPollNum, int maxSendWr = DefaultMaxSendWr,
int maxWrPerSend = DefaultMaxWrPerSend)
: maxCqSize(maxCqSize), maxCqPollNum(maxCqPollNum), maxSendWr(maxSendWr), maxWrPerSend(maxWrPerSend) {}
};

/// Communication transport type (e.g., CudaIpc, IB0-IB7, Ethernet).
Transport transport;
/// Target device for the endpoint (GPU or CPU with optional device ID).
Device device;
int ibMaxCqSize;
int ibMaxCqPollNum;
int ibMaxSendWr;
int ibMaxWrPerSend;
/// Maximum number of write requests that can be queued (-1 for default).
int maxWriteQueueSize;

/// Constructor that takes a transport and sets the other fields to their default values.
///
/// @param transport The transport to use.
/// @param device The device to use.
/// @param ibMaxCqSize The maximum completion queue size.
/// @param ibMaxCqPollNum The maximum completion queue poll number.
/// @param ibMaxSendWr The maximum send work requests.
/// @param ibMaxWrPerSend The maximum work requests per send.
/// @param maxWriteQueueSize The maximum write queue size.
EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU,
int ibMaxCqSize = DefaultMaxCqSize, int ibMaxCqPollNum = DefaultMaxCqPollNum,
int ibMaxSendWr = DefaultMaxSendWr, int ibMaxWrPerSend = DefaultMaxWrPerSend,
int maxWriteQueueSize = -1)
: transport(transport),
device(device),
ibMaxCqSize(ibMaxCqSize),
ibMaxCqPollNum(ibMaxCqPollNum),
ibMaxSendWr(ibMaxSendWr),
ibMaxWrPerSend(ibMaxWrPerSend),
maxWriteQueueSize(maxWriteQueueSize) {}
/// InfiniBand-specific options (used only for Transport::IBx).
Ib ib;

/// Constructs endpoint configuration with specified transport, device, and optional settings.
/// @param transport Communication transport to use.
/// @param device Target device for the endpoint.
/// @param maxWriteQueueSize Maximum write queue size (-1 for system default).
/// @param ib IB-specific configuration.
EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU, int maxWriteQueueSize = -1,
Ib ib = {})
: transport(transport), device(device), maxWriteQueueSize(maxWriteQueueSize), ib(ib) {}
};

class Context;
Expand All @@ -423,6 +435,10 @@ class Endpoint {
/// Constructor.
Endpoint() = default;

/// Get the configuration used to create the endpoint.
/// @return The configuration used to create the endpoint.
const EndpointConfig& config() const;

/// Get the transport used.
/// @return The transport used.
Transport transport() const;
Expand Down Expand Up @@ -685,9 +701,9 @@ class Semaphore {
std::shared_ptr<Impl> pimpl_;
};

/// Deprecated.
template <typename T>
using NonblockingFuture [[deprecated("Use std::shared_future instead. This will be removed in a future release.")]] =
std::shared_future<T>;
using NonblockingFuture = std::shared_future<T>;

/// A class that sets up all registered memories and connections between processes.
///
Expand Down Expand Up @@ -853,12 +869,20 @@ class Communicator {
/// on the last future, it will start receiving the five RegisteredMemory or Connection objects in order,
/// back to back.
///
/// @param localConfig The configuration for the local endpoint.
/// @param localEndpoint The local endpoint.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send and receive.
/// @return A future of shared pointer to the connection.
///
std::shared_future<std::shared_ptr<Connection>> connect(EndpointConfig localConfig, int remoteRank, int tag = 0);
std::shared_future<std::shared_ptr<Connection>> connect(const Endpoint& localEndpoint, int remoteRank, int tag = 0);

/// Connect to a remote rank. Wrapper of `connect(localEndpoint, remoteRank, tag)`.
/// @param localConfig The configuration for the local endpoint.
/// @param remoteRank The rank of the remote process.
/// @param tag The tag to use for identifying the send and receive.
/// @return A future of shared pointer to the connection.
std::shared_future<std::shared_ptr<Connection>> connect(const EndpointConfig& localConfig, int remoteRank,
int tag = 0);

[[deprecated("Use connect(localConfig, remoteRank, tag) instead. This will be removed in a future release.")]] std::
shared_future<std::shared_ptr<Connection>>
Expand Down
45 changes: 32 additions & 13 deletions python/mscclpp/core_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,17 @@ void register_core(nb::module_& m) {
.def_rw("id", &Device::id)
.def("__str__", [](const Device& self) { return std::to_string(self); });

nb::class_<EndpointConfig::Ib>(m, "EndpointConfigIb")
.def(nb::init<>())
.def(nb::init<int, int, int, int>(), nb::arg("maxCqSize") = EndpointConfig::Ib::DefaultMaxCqSize,
nb::arg("maxCqPollNum") = EndpointConfig::Ib::DefaultMaxCqPollNum,
nb::arg("maxSendWr") = EndpointConfig::Ib::DefaultMaxSendWr,
nb::arg("maxWrPerSend") = EndpointConfig::Ib::DefaultMaxWrPerSend)
.def_rw("max_cq_size", &EndpointConfig::Ib::maxCqSize)
.def_rw("max_cq_poll_num", &EndpointConfig::Ib::maxCqPollNum)
.def_rw("max_send_wr", &EndpointConfig::Ib::maxSendWr)
.def_rw("max_wr_per_send", &EndpointConfig::Ib::maxWrPerSend);

nb::class_<RegisteredMemory>(m, "RegisteredMemory")
.def(nb::init<>())
.def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
Expand Down Expand Up @@ -158,17 +169,23 @@ void register_core(nb::module_& m) {
nb::class_<EndpointConfig>(m, "EndpointConfig")
.def(nb::init<>())
.def(nb::init_implicit<Transport>(), nb::arg("transport"))
.def(nb::init<Transport, Device, int, int, int, int, int>(), nb::arg("transport"), nb::arg("device"),
nb::arg("ibMaxCqSize") = EndpointConfig::DefaultMaxCqSize,
nb::arg("ibMaxCqPollNum") = EndpointConfig::DefaultMaxCqPollNum,
nb::arg("ibMaxSendWr") = EndpointConfig::DefaultMaxSendWr,
nb::arg("ibMaxWrPerSend") = EndpointConfig::DefaultMaxWrPerSend, nb::arg("maxWriteQueueSize") = -1)
.def(nb::init<Transport, Device, int, EndpointConfig::Ib>(), nb::arg("transport"), nb::arg("device"),
nb::arg("maxWriteQueueSize") = -1, nb::arg("ib") = EndpointConfig::Ib{})
.def_rw("transport", &EndpointConfig::transport)
.def_rw("device", &EndpointConfig::device)
.def_rw("ib_max_cq_size", &EndpointConfig::ibMaxCqSize)
.def_rw("ib_max_cq_poll_num", &EndpointConfig::ibMaxCqPollNum)
.def_rw("ib_max_send_wr", &EndpointConfig::ibMaxSendWr)
.def_rw("ib_max_wr_per_send", &EndpointConfig::ibMaxWrPerSend)
.def_rw("ib", &EndpointConfig::ib)
.def_prop_rw(
"ib_max_cq_size", [](EndpointConfig& self) { return self.ib.maxCqSize; },
[](EndpointConfig& self, int v) { self.ib.maxCqSize = v; })
.def_prop_rw(
"ib_max_cq_poll_num", [](EndpointConfig& self) { return self.ib.maxCqPollNum; },
[](EndpointConfig& self, int v) { self.ib.maxCqPollNum = v; })
.def_prop_rw(
"ib_max_send_wr", [](EndpointConfig& self) { return self.ib.maxSendWr; },
[](EndpointConfig& self, int v) { self.ib.maxSendWr = v; })
.def_prop_rw(
"ib_max_wr_per_send", [](EndpointConfig& self) { return self.ib.maxWrPerSend; },
[](EndpointConfig& self, int v) { self.ib.maxWrPerSend = v; })
.def_rw("max_write_queue_size", &EndpointConfig::maxWriteQueueSize);

nb::class_<Context>(m, "Context")
Expand Down Expand Up @@ -212,13 +229,15 @@ void register_core(nb::module_& m) {
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag") = 0)
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag") = 0)
.def("connect",
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(EndpointConfig, int, int)>(
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(const Endpoint&, int, int)>(
&Communicator::connect),
nb::arg("localConfig"), nb::arg("remoteRank"), nb::arg("tag") = 0)
nb::arg("localEndpoint"), nb::arg("remoteRank"), nb::arg("tag") = 0)
.def("connect", [](Communicator* self, const EndpointConfig& localConfig, int remoteRank,
int tag = 0) { return self->connect(localConfig, remoteRank, tag); })
.def(
"connect",
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
return self->connect(std::move(localConfig), remoteRank, tag);
[](Communicator* self, int remoteRank, int tag, const EndpointConfig& localConfig) {
return self->connect(localConfig, remoteRank, tag);
},
nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig"))
.def(
Expand Down
41 changes: 22 additions & 19 deletions src/communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,41 +99,44 @@ MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory(in
return shared_future;
}

MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(EndpointConfig localConfig,
MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(const Endpoint& localEndpoint,
int remoteRank, int tag) {
auto localEndpoint = context()->createEndpoint(localConfig);

if (remoteRank == bootstrap()->getRank()) {
// Connection to self
auto remoteEndpoint = context()->createEndpoint(localConfig);
auto remoteEndpoint = context()->createEndpoint(localEndpoint.config());
auto connection = context()->connect(localEndpoint, remoteEndpoint);
std::promise<std::shared_ptr<Connection>> promise;
promise.set_value(connection);
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
return std::shared_future<std::shared_ptr<Connection>>(std::move(promise.get_future()));
return std::shared_future<std::shared_ptr<Connection>>(promise.get_future());
}

bootstrap()->send(localEndpoint.serialize(), remoteRank, tag);

auto future =
std::async(std::launch::deferred, [this, remoteRank, tag, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag),
localEndpoint = std::move(localEndpoint)]() mutable {
if (lastRecvItem) {
// Recursive call to the previous receive items
lastRecvItem->wait();
}
std::vector<char> data;
bootstrap()->recv(data, remoteRank, tag);
auto remoteEndpoint = Endpoint::deserialize(data);
auto connection = context()->connect(localEndpoint, remoteEndpoint);
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
return connection;
});
auto future = std::async(std::launch::deferred, [this, remoteRank, tag, localEndpoint,
lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag)]() mutable {
if (lastRecvItem) {
// Recursive call to the previous receive items
lastRecvItem->wait();
}
std::vector<char> data;
bootstrap()->recv(data, remoteRank, tag);
auto remoteEndpoint = Endpoint::deserialize(data);
auto connection = context()->connect(localEndpoint, remoteEndpoint);
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
return connection;
});
auto shared_future = std::shared_future<std::shared_ptr<Connection>>(std::move(future));
pimpl_->setLastRecvItem(remoteRank, tag, std::make_shared<RecvItem<std::shared_ptr<Connection>>>(shared_future));
return shared_future;
}

MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(const EndpointConfig& localConfig,
int remoteRank, int tag) {
auto localEndpoint = context()->createEndpoint(localConfig);
return connect(localEndpoint, remoteRank, tag);
}

MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(int remoteRank, int tag,
EndpointConfig localConfig) {
return connect(localConfig, remoteRank, tag);
Expand Down
3 changes: 0 additions & 3 deletions src/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,6 @@ IBConnection::IBConnection(std::shared_ptr<Context> context, const Endpoint& loc
transport_(localEndpoint.transport()),
remoteTransport_(remoteEndpoint.transport()),
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
if (maxWriteQueueSize_ == -1) {
maxWriteQueueSize_ = EndpointConfig::DefaultMaxCqSize;
}
qp_ = getImpl(localEndpoint).ibQp_;
qp_.lock()->rtr(getImpl(remoteEndpoint).ibQpInfo_);
qp_.lock()->rts();
Expand Down
25 changes: 13 additions & 12 deletions src/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "context.hpp"

#include <mscclpp/env.hpp>
#include <sstream>

#include "api.h"
#include "connection.hpp"
Expand Down Expand Up @@ -76,21 +77,21 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(const Endpoint &loc
if (remoteEndpoint.device().type == DeviceType::GPU && remoteEndpoint.device().id < 0) {
throw Error("No GPU device ID provided for remote endpoint", ErrorCode::InvalidUsage);
}
auto localTransport = localEndpoint.transport();
auto remoteTransport = remoteEndpoint.transport();
if (localTransport != remoteTransport &&
!(AllIBTransports.has(localTransport) && AllIBTransports.has(remoteTransport))) {
std::stringstream ss;
ss << "Transport mismatch between local (" << std::to_string(localTransport) << ") and remote ("
<< std::to_string(remoteEndpoint.transport()) << ") endpoints";
throw Error(ss.str(), ErrorCode::InvalidUsage);
}
std::shared_ptr<Connection> conn;
if (localEndpoint.transport() == Transport::CudaIpc) {
if (remoteEndpoint.transport() != Transport::CudaIpc) {
throw Error("Local transport is CudaIpc but remote is not", ErrorCode::InvalidUsage);
}
if (localTransport == Transport::CudaIpc) {
conn = std::make_shared<CudaIpcConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
} else if (AllIBTransports.has(localEndpoint.transport())) {
if (!AllIBTransports.has(remoteEndpoint.transport())) {
throw Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
}
} else if (AllIBTransports.has(localTransport)) {
conn = std::make_shared<IBConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
} else if (localEndpoint.transport() == Transport::Ethernet) {
if (remoteEndpoint.transport() != Transport::Ethernet) {
throw Error("Local transport is Ethernet but remote is not", ErrorCode::InvalidUsage);
}
} else if (localTransport == Transport::Ethernet) {
conn = std::make_shared<EthernetConnection>(shared_from_this(), localEndpoint, remoteEndpoint);
} else {
throw Error("Unsupported transport", ErrorCode::InternalError);
Expand Down
Loading
Loading