@@ -361,6 +361,7 @@ enum class DeviceType {
361
361
GPU, // GPU device type.
362
362
};
363
363
364
+ // / Declaration of a device.
364
365
struct Device {
365
366
// / Constructor.
366
367
Device () = default ;
@@ -377,8 +378,149 @@ struct Device {
377
378
int id;
378
379
};
379
380
381
+ // / Settings used to configure an endpoint; an instance is what Context::createEndpoint() takes as its argument.
382
+ struct EndpointConfig {
383
+ static const int DefaultMaxCqSize = 1024 ;  // Default for ibMaxCqSize.
384
+ static const int DefaultMaxCqPollNum = 1 ;  // Default for ibMaxCqPollNum.
385
+ static const int DefaultMaxSendWr = 8192 ;  // Default for ibMaxSendWr.
386
+ static const int DefaultMaxWrPerSend = 64 ;  // Default for ibMaxWrPerSend.
387
+
388
+ Transport transport;
389
+ Device device;
390
+ int ibMaxCqSize;
391
+ int ibMaxCqPollNum;
392
+ int ibMaxSendWr;
393
+ int ibMaxWrPerSend;
394
+ int maxWriteQueueSize;
395
+
396
+ // / Constructor. Every parameter has a default, so any trailing arguments may be omitted. It is not marked
397
+ // / explicit, so a lone Transport value converts implicitly to an EndpointConfig.
398
+ // / @param transport The transport to use. Defaults to Transport::Unknown.
399
+ // / @param device The device to use. Defaults to a Device converted from DeviceType::GPU.
400
+ // / @param ibMaxCqSize The maximum completion queue size. Defaults to DefaultMaxCqSize (1024).
401
+ // / @param ibMaxCqPollNum The maximum completion queue poll number. Defaults to DefaultMaxCqPollNum (1).
402
+ // / @param ibMaxSendWr The maximum send work requests. Defaults to DefaultMaxSendWr (8192).
403
+ // / @param ibMaxWrPerSend The maximum work requests per send. Defaults to DefaultMaxWrPerSend (64).
404
+ // / @param maxWriteQueueSize The maximum write queue size. Defaults to -1 (what -1 means is not visible here; TODO confirm).
405
+ EndpointConfig (Transport transport = Transport::Unknown, Device device = DeviceType::GPU,
406
+ int ibMaxCqSize = DefaultMaxCqSize, int ibMaxCqPollNum = DefaultMaxCqPollNum,
407
+ int ibMaxSendWr = DefaultMaxSendWr, int ibMaxWrPerSend = DefaultMaxWrPerSend,
408
+ int maxWriteQueueSize = -1 )
409
+ : transport(transport),
410
+ device (device),
411
+ ibMaxCqSize(ibMaxCqSize),
412
+ ibMaxCqPollNum(ibMaxCqPollNum),
413
+ ibMaxSendWr(ibMaxSendWr),
414
+ ibMaxWrPerSend(ibMaxWrPerSend),
415
+ maxWriteQueueSize(maxWriteQueueSize) {}  // Each argument is copied into the member of the same name.
416
+ };
417
+
380
418
class Context ;
381
419
class Connection ;
420
+ class RegisteredMemory ;
421
+ class SemaphoreStub ;
422
+
423
+ // / One end of a connection. Context::createEndpoint() returns an Endpoint, and Context::connect() takes a local and a remote one.
424
+ class Endpoint {
425
+ public:
426
+ // / Default constructor. Leaves pimpl_ as an empty (null) std::shared_ptr.
427
+ Endpoint () = default ;
428
+
429
+ // / Get the transport used by this endpoint.
430
+ // / @return The transport used, returned by value.
431
+ Transport transport () const ;
432
+
433
+ // / Get the device used by this endpoint.
434
+ // / @return A const reference to the device used.
435
+ const Device& device () const ;
436
+
437
+ // / Get the host hash of this endpoint.
438
+ // / @return The host hash, as a 64-bit unsigned integer.
439
+ uint64_t hostHash () const ;
440
+
441
+ // / Get the process ID hash of this endpoint.
442
+ // / @return The process ID hash, as a 64-bit unsigned integer.
443
+ uint64_t pidHash () const ;
444
+
445
+ // / Get the maximum write queue size of this endpoint.
446
+ // / @return The maximum number of write requests that can be queued.
447
+ int maxWriteQueueSize () const ;
448
+
449
+ // / Serialize the Endpoint object to a vector of characters; deserialize() is the matching reverse operation.
450
+ // / @return A vector of characters representing the serialized Endpoint object.
451
+ std::vector<char > serialize () const ;
452
+
453
+ // / Deserialize an Endpoint object from a vector of characters. Static, so no existing Endpoint is needed.
454
+ // / @param data A vector of characters representing a serialized Endpoint object, as returned by serialize().
455
+ // / @return A deserialized Endpoint object.
456
+ static Endpoint deserialize (const std::vector<char >& data);
457
+
458
+ private:
459
+ struct Impl ;  // Implementation type; only declared here, its definition is not part of this class body.
460
+ Endpoint (std::shared_ptr<Impl> pimpl);  // Private constructor taking an Impl handle; callable only by members and the friends below.
461
+ std::shared_ptr<Impl> pimpl_;  // Shared handle to the implementation; copying an Endpoint copies this pointer, not the Impl.
462
+
463
+ friend class Context ;  // Context may access the private members above.
464
+ friend class Connection ;  // Connection may access the private members above.
465
+ };
466
+
467
+ // / Context for communication. This provides a low-level interface for forming connections in use-cases
468
+ // / where the process group abstraction offered by Communicator is not suitable, e.g., ephemeral client-server
469
+ // / connections. Correct use of this class requires external synchronization when finalizing connections with the
470
+ // / connect() method.
471
+ // /
472
+ // / As an example, a client-server scenario where the server will write to the client might proceed as follows:
473
+ // / 1. The client creates an endpoint with createEndpoint() and sends it to the server.
474
+ // / 2. The server receives the client endpoint, creates its own endpoint with createEndpoint(), sends it to the
475
+ // / client, and creates a connection with connect().
476
+ // / 3. The client receives the server endpoint, creates a connection with connect() and sends a
477
+ // / RegisteredMemory to the server.
478
+ // / 4. The server receives the RegisteredMemory and writes to it using the previously created connection.
479
+ // / Because the client creates its connection before sending the RegisteredMemory, the server cannot
480
+ // / write to the RegisteredMemory before the connection is established.
481
+ // /
482
+ // / While some transports may have more relaxed implementation behavior, this should not be relied upon.
483
+ class Context : public std ::enable_shared_from_this<Context> {
484
+ public:
485
+ // / Create a new Context instance owned by a std::shared_ptr. The constructor is private, so this is the public way to obtain one.
486
+ static std::shared_ptr<Context> create () { return std::shared_ptr<Context>(new Context ()); }
487
+
488
+ // / Destructor. Only declared here; the definition is not part of this class body.
489
+ ~Context ();
490
+
491
+ // / Register a region of GPU memory for use in this context.
492
+ // /
493
+ // / @param ptr Base pointer to the memory.
494
+ // / @param size Size of the memory region in bytes.
495
+ // / @param transports Transport flags (presumably the transports the memory is registered for; TODO confirm).
496
+ // / @return A RegisteredMemory object representing the registered memory region.
497
+ RegisteredMemory registerMemory (void * ptr, size_t size, TransportFlags transports);
498
+
499
+ // / Create an endpoint for establishing connections. The result is what connect() takes as an argument.
500
+ // /
501
+ // / @param config The configuration for the endpoint, taken by value.
502
+ // / @return The newly created endpoint.
503
+ Endpoint createEndpoint (EndpointConfig config);
504
+
505
+ // / Establish a connection between two endpoints. While this method immediately returns a connection object, the
506
+ // / connection is only safe to use after the corresponding connection on the remote endpoint has been established.
507
+ // / This method must be called on both endpoints to establish a connection.
508
+ // /
509
+ // / @param localEndpoint The local endpoint, taken by const reference.
510
+ // / @param remoteEndpoint The remote endpoint, taken by const reference.
511
+ // / @return A shared pointer to the connection.
512
+ std::shared_ptr<Connection> connect (const Endpoint& localEndpoint, const Endpoint& remoteEndpoint);
513
+
514
+ private:
515
+ Context ();  // Private constructor; create() is the caller visible in this class.
516
+
517
+ struct Impl ;  // Implementation type; only declared here.
518
+ std::unique_ptr<Impl> pimpl_;  // Exclusive owner of the implementation object.
519
+
520
+ friend class Endpoint ;  // Endpoint may access the private members above.
521
+ friend class Connection ;  // Connection may access the private members above.
522
+ friend class RegisteredMemory ;  // RegisteredMemory may access the private members above.
523
+ };
382
524
383
525
// / Block of memory that has been registered to a Context.
384
526
// / RegisteredMemory does not own the memory it points to, but it provides a way to transfer metadata about the memory
@@ -426,50 +568,6 @@ class RegisteredMemory {
426
568
friend class SemaphoreStub ;
427
569
};
428
570
429
- // / One end of a connection. Context::createEndpoint() returns an Endpoint, and Context::connect() takes a local and a remote one.
430
- class Endpoint {
431
- public:
432
- // / Default constructor. Leaves pimpl_ as an empty (null) std::shared_ptr.
433
- Endpoint () = default ;
434
-
435
- // / Get the transport used by this endpoint.
436
- // / @return The transport used, returned by value.
437
- Transport transport () const ;
438
-
439
- // / Get the device used by this endpoint.
440
- // / @return A const reference to the device used.
441
- const Device& device () const ;
442
-
443
- // / Get the host hash of this endpoint.
444
- // / @return The host hash, as a 64-bit unsigned integer.
445
- uint64_t hostHash () const ;
446
-
447
- // / Get the process ID hash of this endpoint.
448
- // / @return The process ID hash, as a 64-bit unsigned integer.
449
- uint64_t pidHash () const ;
450
-
451
- // / Get the maximum write queue size of this endpoint.
452
- // / @return The maximum number of write requests that can be queued.
453
- int maxWriteQueueSize () const ;
454
-
455
- // / Serialize the Endpoint object to a vector of characters; deserialize() is the matching reverse operation.
456
- // / @return A vector of characters representing the serialized Endpoint object.
457
- std::vector<char > serialize () const ;
458
-
459
- // / Deserialize an Endpoint object from a vector of characters. Static, so no existing Endpoint is needed.
460
- // / @param data A vector of characters representing a serialized Endpoint object, as returned by serialize().
461
- // / @return A deserialized Endpoint object.
462
- static Endpoint deserialize (const std::vector<char >& data);
463
-
464
- private:
465
- struct Impl ;  // Implementation type; only declared here, its definition is not part of this class body.
466
- Endpoint (std::shared_ptr<Impl> pimpl);  // Private constructor taking an Impl handle; callable only by members and the friends below.
467
- std::shared_ptr<Impl> pimpl_;  // Shared handle to the implementation; copying an Endpoint copies this pointer, not the Impl.
468
-
469
- friend class Context ;  // Context may access the private members above.
470
- friend class Connection ;  // Connection may access the private members above.
471
- };
472
-
473
571
// / Connection between two processes.
474
572
class Connection {
475
573
public:
@@ -524,108 +622,15 @@ class Connection {
524
622
int getMaxWriteQueueSize () const ;
525
623
526
624
protected:
527
- static std::shared_ptr<RegisteredMemory::Impl> getImpl (RegisteredMemory& memory);
528
- static std::shared_ptr<Endpoint::Impl> getImpl (Endpoint& memory);
625
+ static const Endpoint::Impl& getImpl (const Endpoint& endpoint);
626
+ static const RegisteredMemory::Impl& getImpl (const RegisteredMemory& memory);
627
+ static Context::Impl& getImpl (Context& context);
529
628
530
629
std::shared_ptr<Context> context_;
531
630
Endpoint localEndpoint_;
532
631
int maxWriteQueueSize_;
533
632
};
534
633
535
- // / Settings used to configure an endpoint; an instance is what Context::createEndpoint() takes as its argument.
536
- struct EndpointConfig {
537
- static const int DefaultMaxCqSize = 1024 ;  // Default for ibMaxCqSize.
538
- static const int DefaultMaxCqPollNum = 1 ;  // Default for ibMaxCqPollNum.
539
- static const int DefaultMaxSendWr = 8192 ;  // Default for ibMaxSendWr.
540
- static const int DefaultMaxWrPerSend = 64 ;  // Default for ibMaxWrPerSend.
541
-
542
- Transport transport;
543
- Device device;
544
- int ibMaxCqSize;
545
- int ibMaxCqPollNum;
546
- int ibMaxSendWr;
547
- int ibMaxWrPerSend;
548
- int maxWriteQueueSize;
549
-
550
- // / Constructor. Every parameter has a default, so any trailing arguments may be omitted. It is not marked
551
- // / explicit, so a lone Transport value converts implicitly to an EndpointConfig.
552
- // / @param transport The transport to use. Defaults to Transport::Unknown.
553
- // / @param device The device to use. Defaults to a Device converted from DeviceType::GPU.
554
- // / @param ibMaxCqSize The maximum completion queue size. Defaults to DefaultMaxCqSize (1024).
555
- // / @param ibMaxCqPollNum The maximum completion queue poll number. Defaults to DefaultMaxCqPollNum (1).
556
- // / @param ibMaxSendWr The maximum send work requests. Defaults to DefaultMaxSendWr (8192).
557
- // / @param ibMaxWrPerSend The maximum work requests per send. Defaults to DefaultMaxWrPerSend (64).
558
- // / @param maxWriteQueueSize The maximum write queue size. Defaults to -1 (what -1 means is not visible here; TODO confirm).
559
- EndpointConfig (Transport transport = Transport::Unknown, Device device = DeviceType::GPU,
560
- int ibMaxCqSize = DefaultMaxCqSize, int ibMaxCqPollNum = DefaultMaxCqPollNum,
561
- int ibMaxSendWr = DefaultMaxSendWr, int ibMaxWrPerSend = DefaultMaxWrPerSend,
562
- int maxWriteQueueSize = -1 )
563
- : transport(transport),
564
- device (device),
565
- ibMaxCqSize(ibMaxCqSize),
566
- ibMaxCqPollNum(ibMaxCqPollNum),
567
- ibMaxSendWr(ibMaxSendWr),
568
- ibMaxWrPerSend(ibMaxWrPerSend),
569
- maxWriteQueueSize(maxWriteQueueSize) {}  // Each argument is copied into the member of the same name.
570
- };
571
-
572
- // / Context for communication. This provides a low-level interface for forming connections in use-cases
573
- // / where the process group abstraction offered by Communicator is not suitable, e.g., ephemeral client-server
574
- // / connections. Correct use of this class requires external synchronization when finalizing connections with the
575
- // / connect() method.
576
- // /
577
- // / As an example, a client-server scenario where the server will write to the client might proceed as follows:
578
- // / 1. The client creates an endpoint with createEndpoint() and sends it to the server.
579
- // / 2. The server receives the client endpoint, creates its own endpoint with createEndpoint(), sends it to the
580
- // / client, and creates a connection with connect().
581
- // / 3. The client receives the server endpoint, creates a connection with connect() and sends a
582
- // / RegisteredMemory to the server.
583
- // / 4. The server receives the RegisteredMemory and writes to it using the previously created connection.
584
- // / Because the client creates its connection before sending the RegisteredMemory, the server cannot
585
- // / write to the RegisteredMemory before the connection is established.
586
- // /
587
- // / While some transports may have more relaxed implementation behavior, this should not be relied upon.
588
- class Context : public std ::enable_shared_from_this<Context> {
589
- public:
590
- // / Create a new Context instance owned by a std::shared_ptr. The constructor is private, so this is the public way to obtain one.
591
- static std::shared_ptr<Context> create () { return std::shared_ptr<Context>(new Context ()); }
592
-
593
- // / Destructor. Only declared here; the definition is not part of this class body.
594
- ~Context ();
595
-
596
- // / Register a region of GPU memory for use in this context.
597
- // /
598
- // / @param ptr Base pointer to the memory.
599
- // / @param size Size of the memory region in bytes.
600
- // / @param transports Transport flags (presumably the transports the memory is registered for; TODO confirm).
601
- // / @return A RegisteredMemory object representing the registered memory region.
602
- RegisteredMemory registerMemory (void * ptr, size_t size, TransportFlags transports);
603
-
604
- // / Create an endpoint for establishing connections. The result is what connect() takes as an argument.
605
- // /
606
- // / @param config The configuration for the endpoint, taken by value.
607
- // / @return The newly created endpoint.
608
- Endpoint createEndpoint (EndpointConfig config);
609
-
610
- // / Establish a connection between two endpoints. While this method immediately returns a connection object, the
611
- // / connection is only safe to use after the corresponding connection on the remote endpoint has been established.
612
- // / This method must be called on both endpoints to establish a connection.
613
- // /
614
- // / @param localEndpoint The local endpoint, taken by value.
615
- // / @param remoteEndpoint The remote endpoint, taken by value.
616
- // / @return A shared pointer to the connection.
617
- std::shared_ptr<Connection> connect (Endpoint localEndpoint, Endpoint remoteEndpoint);
618
-
619
- private:
620
- Context ();  // Private constructor; create() is the caller visible in this class.
621
-
622
- struct Impl ;  // Implementation type; only declared here.
623
- std::unique_ptr<Impl> pimpl_;  // Exclusive owner of the implementation object.
624
-
625
- friend class RegisteredMemory ;  // RegisteredMemory may access the private members above.
626
- friend class Endpoint ;  // Endpoint may access the private members above.
627
- };
628
-
629
634
// / SemaphoreStub object only used for constructing Semaphore, not for direct use by the user.
630
635
class SemaphoreStub {
631
636
public:
0 commit comments