Skip to content

Commit d55ac96

Browse files
authored
Fixed the local channel test (#597)
1 parent 658411c commit d55ac96

File tree

4 files changed

+217
-206
lines changed

4 files changed

+217
-206
lines changed

include/mscclpp/core.hpp

Lines changed: 145 additions & 140 deletions
Original file line number | Diff line number | Diff line change
@@ -361,6 +361,7 @@ enum class DeviceType {
361361
GPU, // GPU device type.
362362
};
363363

364+
/// Declaration of a device.
364365
struct Device {
365366
/// Constructor.
366367
Device() = default;
@@ -377,8 +378,149 @@ struct Device {
377378
int id;
378379
};
379380

381+
/// Used to configure an endpoint.
382+
struct EndpointConfig {
383+
static const int DefaultMaxCqSize = 1024;
384+
static const int DefaultMaxCqPollNum = 1;
385+
static const int DefaultMaxSendWr = 8192;
386+
static const int DefaultMaxWrPerSend = 64;
387+
388+
Transport transport;
389+
Device device;
390+
int ibMaxCqSize;
391+
int ibMaxCqPollNum;
392+
int ibMaxSendWr;
393+
int ibMaxWrPerSend;
394+
int maxWriteQueueSize;
395+
396+
/// Constructor. Every parameter, including the transport, has a default value, so trailing arguments may be omitted.
397+
///
398+
/// @param transport The transport to use.
399+
/// @param device The device to use.
400+
/// @param ibMaxCqSize The maximum completion queue size.
401+
/// @param ibMaxCqPollNum The maximum completion queue poll number.
402+
/// @param ibMaxSendWr The maximum send work requests.
403+
/// @param ibMaxWrPerSend The maximum work requests per send.
404+
/// @param maxWriteQueueSize The maximum write queue size.
405+
EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU,
406+
int ibMaxCqSize = DefaultMaxCqSize, int ibMaxCqPollNum = DefaultMaxCqPollNum,
407+
int ibMaxSendWr = DefaultMaxSendWr, int ibMaxWrPerSend = DefaultMaxWrPerSend,
408+
int maxWriteQueueSize = -1)
409+
: transport(transport),
410+
device(device),
411+
ibMaxCqSize(ibMaxCqSize),
412+
ibMaxCqPollNum(ibMaxCqPollNum),
413+
ibMaxSendWr(ibMaxSendWr),
414+
ibMaxWrPerSend(ibMaxWrPerSend),
415+
maxWriteQueueSize(maxWriteQueueSize) {}
416+
};
417+
380418
class Context;
381419
class Connection;
420+
class RegisteredMemory;
421+
class SemaphoreStub;
422+
423+
/// One end of a connection.
424+
class Endpoint {
425+
public:
426+
/// Constructor.
427+
Endpoint() = default;
428+
429+
/// Get the transport used.
430+
/// @return The transport used.
431+
Transport transport() const;
432+
433+
/// Get the device used.
434+
/// @return The device used.
435+
const Device& device() const;
436+
437+
/// Get the host hash.
438+
/// @return The host hash.
439+
uint64_t hostHash() const;
440+
441+
/// Get the process ID hash.
442+
/// @return The process ID hash.
443+
uint64_t pidHash() const;
444+
445+
/// Get the maximum write queue size.
446+
/// @return The maximum number of write requests that can be queued.
447+
int maxWriteQueueSize() const;
448+
449+
/// Serialize the Endpoint object to a vector of characters.
450+
/// @return A vector of characters representing the serialized Endpoint object.
451+
std::vector<char> serialize() const;
452+
453+
/// Deserialize an Endpoint object from a vector of characters.
454+
/// @param data A vector of characters representing a serialized Endpoint object.
455+
/// @return A deserialized Endpoint object.
456+
static Endpoint deserialize(const std::vector<char>& data);
457+
458+
private:
459+
struct Impl;
460+
Endpoint(std::shared_ptr<Impl> pimpl);
461+
std::shared_ptr<Impl> pimpl_;
462+
463+
friend class Context;
464+
friend class Connection;
465+
};
466+
467+
/// Context for communication. This provides a low-level interface for forming connections in use-cases
468+
/// where the process group abstraction offered by Communicator is not suitable, e.g., ephemeral client-server
469+
/// connections. Correct use of this class requires external synchronization when finalizing connections with the
470+
/// connect() method.
471+
///
472+
/// As an example, a client-server scenario where the server will write to the client might proceed as follows:
473+
/// 1. The client creates an endpoint with createEndpoint() and sends it to the server.
474+
/// 2. The server receives the client endpoint, creates its own endpoint with createEndpoint(), sends it to the
475+
/// client, and creates a connection with connect().
476+
/// 3. The client receives the server endpoint, creates a connection with connect() and sends a
477+
/// RegisteredMemory to the server.
478+
/// 4. The server receives the RegisteredMemory and writes to it using the previously created connection.
479+
/// The client waiting to create a connection before sending the RegisteredMemory ensures that the server cannot
480+
/// write to the RegisteredMemory before the connection is established.
481+
///
482+
/// While some transports may have more relaxed implementation behavior, this should not be relied upon.
483+
class Context : public std::enable_shared_from_this<Context> {
484+
public:
485+
/// Create a new Context instance.
486+
static std::shared_ptr<Context> create() { return std::shared_ptr<Context>(new Context()); }
487+
488+
/// Destructor.
489+
~Context();
490+
491+
/// Register a region of GPU memory for use in this context.
492+
///
493+
/// @param ptr Base pointer to the memory.
494+
/// @param size Size of the memory region in bytes.
495+
/// @param transports Transport flags.
496+
/// @return A RegisteredMemory object representing the registered memory region.
497+
RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);
498+
499+
/// Create an endpoint for establishing connections.
500+
///
501+
/// @param config The configuration for the endpoint.
502+
/// @return The newly created endpoint.
503+
Endpoint createEndpoint(EndpointConfig config);
504+
505+
/// Establish a connection between two endpoints. While this method immediately returns a connection object, the
506+
/// connection is only safe to use after the corresponding connection on the remote endpoint has been established.
507+
/// This method must be called on both endpoints to establish a connection.
508+
///
509+
/// @param localEndpoint The local endpoint.
510+
/// @param remoteEndpoint The remote endpoint.
511+
/// @return A shared pointer to the connection.
512+
std::shared_ptr<Connection> connect(const Endpoint& localEndpoint, const Endpoint& remoteEndpoint);
513+
514+
private:
515+
Context();
516+
517+
struct Impl;
518+
std::unique_ptr<Impl> pimpl_;
519+
520+
friend class Endpoint;
521+
friend class Connection;
522+
friend class RegisteredMemory;
523+
};
382524

383525
/// Block of memory that has been registered to a Context.
384526
/// RegisteredMemory does not own the memory it points to, but it provides a way to transfer metadata about the memory
@@ -426,50 +568,6 @@ class RegisteredMemory {
426568
friend class SemaphoreStub;
427569
};
428570

429-
/// One end of a connection.
430-
class Endpoint {
431-
public:
432-
/// Constructor.
433-
Endpoint() = default;
434-
435-
/// Get the transport used.
436-
/// @return The transport used.
437-
Transport transport() const;
438-
439-
/// Get the device used.
440-
/// @return The device used.
441-
const Device& device() const;
442-
443-
/// Get the host hash.
444-
/// @return The host hash.
445-
uint64_t hostHash() const;
446-
447-
/// Get the process ID hash.
448-
/// @return The process ID hash.
449-
uint64_t pidHash() const;
450-
451-
/// Get the maximum write queue size.
452-
/// @return The maximum number of write requests that can be queued.
453-
int maxWriteQueueSize() const;
454-
455-
/// Serialize the Endpoint object to a vector of characters.
456-
/// @return A vector of characters representing the serialized Endpoint object.
457-
std::vector<char> serialize() const;
458-
459-
/// Deserialize an Endpoint object from a vector of characters.
460-
/// @param data A vector of characters representing a serialized Endpoint object.
461-
/// @return A deserialized Endpoint object.
462-
static Endpoint deserialize(const std::vector<char>& data);
463-
464-
private:
465-
struct Impl;
466-
Endpoint(std::shared_ptr<Impl> pimpl);
467-
std::shared_ptr<Impl> pimpl_;
468-
469-
friend class Context;
470-
friend class Connection;
471-
};
472-
473571
/// Connection between two processes.
474572
class Connection {
475573
public:
@@ -524,108 +622,15 @@ class Connection {
524622
int getMaxWriteQueueSize() const;
525623

526624
protected:
527-
static std::shared_ptr<RegisteredMemory::Impl> getImpl(RegisteredMemory& memory);
528-
static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
625+
static const Endpoint::Impl& getImpl(const Endpoint& endpoint);
626+
static const RegisteredMemory::Impl& getImpl(const RegisteredMemory& memory);
627+
static Context::Impl& getImpl(Context& context);
529628

530629
std::shared_ptr<Context> context_;
531630
Endpoint localEndpoint_;
532631
int maxWriteQueueSize_;
533632
};
534633

535-
/// Used to configure an endpoint.
536-
struct EndpointConfig {
537-
static const int DefaultMaxCqSize = 1024;
538-
static const int DefaultMaxCqPollNum = 1;
539-
static const int DefaultMaxSendWr = 8192;
540-
static const int DefaultMaxWrPerSend = 64;
541-
542-
Transport transport;
543-
Device device;
544-
int ibMaxCqSize;
545-
int ibMaxCqPollNum;
546-
int ibMaxSendWr;
547-
int ibMaxWrPerSend;
548-
int maxWriteQueueSize;
549-
550-
/// Constructor. Every parameter, including the transport, has a default value, so trailing arguments may be omitted.
551-
///
552-
/// @param transport The transport to use.
553-
/// @param device The device to use.
554-
/// @param ibMaxCqSize The maximum completion queue size.
555-
/// @param ibMaxCqPollNum The maximum completion queue poll number.
556-
/// @param ibMaxSendWr The maximum send work requests.
557-
/// @param ibMaxWrPerSend The maximum work requests per send.
558-
/// @param maxWriteQueueSize The maximum write queue size.
559-
EndpointConfig(Transport transport = Transport::Unknown, Device device = DeviceType::GPU,
560-
int ibMaxCqSize = DefaultMaxCqSize, int ibMaxCqPollNum = DefaultMaxCqPollNum,
561-
int ibMaxSendWr = DefaultMaxSendWr, int ibMaxWrPerSend = DefaultMaxWrPerSend,
562-
int maxWriteQueueSize = -1)
563-
: transport(transport),
564-
device(device),
565-
ibMaxCqSize(ibMaxCqSize),
566-
ibMaxCqPollNum(ibMaxCqPollNum),
567-
ibMaxSendWr(ibMaxSendWr),
568-
ibMaxWrPerSend(ibMaxWrPerSend),
569-
maxWriteQueueSize(maxWriteQueueSize) {}
570-
};
571-
572-
/// Context for communication. This provides a low-level interface for forming connections in use-cases
573-
/// where the process group abstraction offered by Communicator is not suitable, e.g., ephemeral client-server
574-
/// connections. Correct use of this class requires external synchronization when finalizing connections with the
575-
/// connect() method.
576-
///
577-
/// As an example, a client-server scenario where the server will write to the client might proceed as follows:
578-
/// 1. The client creates an endpoint with createEndpoint() and sends it to the server.
579-
/// 2. The server receives the client endpoint, creates its own endpoint with createEndpoint(), sends it to the
580-
/// client, and creates a connection with connect().
581-
/// 3. The client receives the server endpoint, creates a connection with connect() and sends a
582-
/// RegisteredMemory to the server.
583-
/// 4. The server receives the RegisteredMemory and writes to it using the previously created connection.
584-
/// The client waiting to create a connection before sending the RegisteredMemory ensures that the server cannot
585-
/// write to the RegisteredMemory before the connection is established.
586-
///
587-
/// While some transports may have more relaxed implementation behavior, this should not be relied upon.
588-
class Context : public std::enable_shared_from_this<Context> {
589-
public:
590-
/// Create a new Context instance.
591-
static std::shared_ptr<Context> create() { return std::shared_ptr<Context>(new Context()); }
592-
593-
/// Destructor.
594-
~Context();
595-
596-
/// Register a region of GPU memory for use in this context.
597-
///
598-
/// @param ptr Base pointer to the memory.
599-
/// @param size Size of the memory region in bytes.
600-
/// @param transports Transport flags.
601-
/// @return A RegisteredMemory object representing the registered memory region.
602-
RegisteredMemory registerMemory(void* ptr, size_t size, TransportFlags transports);
603-
604-
/// Create an endpoint for establishing connections.
605-
///
606-
/// @param config The configuration for the endpoint.
607-
/// @return The newly created endpoint.
608-
Endpoint createEndpoint(EndpointConfig config);
609-
610-
/// Establish a connection between two endpoints. While this method immediately returns a connection object, the
611-
/// connection is only safe to use after the corresponding connection on the remote endpoint has been established.
612-
/// This method must be called on both endpoints to establish a connection.
613-
///
614-
/// @param localEndpoint The local endpoint.
615-
/// @param remoteEndpoint The remote endpoint.
616-
/// @return A shared pointer to the connection.
617-
std::shared_ptr<Connection> connect(Endpoint localEndpoint, Endpoint remoteEndpoint);
618-
619-
private:
620-
Context();
621-
622-
struct Impl;
623-
std::unique_ptr<Impl> pimpl_;
624-
625-
friend class RegisteredMemory;
626-
friend class Endpoint;
627-
};
628-
629634
/// SemaphoreStub object only used for constructing Semaphore, not for direct use by the user.
630635
class SemaphoreStub {
631636
public:

0 commit comments

Comments
 (0)