Skip to content

Commit c6ad3a8

Browse files
committed
Convert listen_comm function pointers to virtual methods
Replace the `accept()` and `close()` function pointers in `nccl_net_ofi_listen_comm` with virtual methods. This eliminates manual function pointer assignment and provides compile-time verification that all implementations provide these methods.

Before: `listen_comm->accept(listen_comm, &recv_comm);`
After: `listen_comm->accept(&recv_comm);`

Signed-off-by: Bibrak Qamar Chandio <bibracha@amazon.com>
1 parent 5f4202f commit c6ad3a8

File tree

12 files changed

+97
-118
lines changed

12 files changed

+97
-118
lines changed

include/gin/nccl_ofi_gin.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -29,18 +29,18 @@ inline nccl_ofi_device_copy &get_device_copy()
2929
class nccl_ofi_gin_listen_comm {
3030
private:
3131
nccl_net_ofi_ep_t *ep;
32-
nccl_net_ofi_listen_comm_t *l_comm;
32+
nccl_net_ofi_listen_comm *l_comm;
3333

3434
public:
3535
nccl_ofi_gin_listen_comm(int dev_arg, nccl_net_ofi_ep_t *ep_arg,
36-
nccl_net_ofi_listen_comm_t *l_comm_arg)
36+
nccl_net_ofi_listen_comm *l_comm_arg)
3737
: ep(ep_arg), l_comm(l_comm_arg)
3838
{
3939
}
4040

4141
~nccl_ofi_gin_listen_comm()
4242
{
43-
int ret = l_comm->close(l_comm);
43+
int ret = l_comm->close();
4444
if (ret != 0) {
4545
NCCL_OFI_WARN("GIN: Unable to close net listen comm: %d", ret);
4646
}

include/nccl_ofi.h

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -111,12 +111,11 @@ class nccl_net_ofi_ep_t;
111111
class nccl_net_ofi_plugin_t;
112112

113113
struct nccl_net_ofi_comm;
114-
struct nccl_net_ofi_listen_comm;
114+
class nccl_net_ofi_listen_comm;
115115
struct nccl_net_ofi_send_comm;
116116
struct nccl_net_ofi_recv_comm;
117117

118118
typedef struct nccl_net_ofi_comm nccl_net_ofi_comm_t;
119-
typedef struct nccl_net_ofi_listen_comm nccl_net_ofi_listen_comm_t;
120119
typedef struct nccl_net_ofi_send_comm nccl_net_ofi_send_comm_t;
121120
typedef struct nccl_net_ofi_recv_comm nccl_net_ofi_recv_comm_t;
122121

@@ -682,7 +681,7 @@ class nccl_net_ofi_ep_t {
682681
* handle is set to COMM_CREATE_START.
683682
*/
684683
virtual int listen(nccl_net_ofi_conn_handle_t *handle,
685-
nccl_net_ofi_listen_comm_t **listen_comm) = 0;
684+
nccl_net_ofi_listen_comm **listen_comm) = 0;
686685

687686
/* Create a connection to a process that has called
688687
* listen().
@@ -864,12 +863,14 @@ struct nccl_net_ofi_comm {
864863
/**
865864
* Listen Communicator - Communicator for a listen/accept pairing
866865
*/
867-
struct nccl_net_ofi_listen_comm {
866+
class nccl_net_ofi_listen_comm {
867+
public:
868868
nccl_net_ofi_comm_t base;
869869

870-
int (*accept)(nccl_net_ofi_listen_comm_t *listen_comm,
871-
nccl_net_ofi_recv_comm_t **recv_comm);
872-
int (*close)(nccl_net_ofi_listen_comm_t *listen_comm);
870+
virtual ~nccl_net_ofi_listen_comm() = default;
871+
872+
virtual int accept(nccl_net_ofi_recv_comm_t **recv_comm) = 0;
873+
virtual int close() = 0;
873874
};
874875

875876
struct nccl_net_ofi_send_comm {

include/nccl_ofi_rdma.h

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -684,11 +684,10 @@ typedef struct nccl_net_ofi_rdma_recv_comm {
684684
uint64_t remote_mr_key[MAX_NUM_RAILS];
685685
} nccl_net_ofi_rdma_recv_comm_t;
686686

687-
typedef struct nccl_net_ofi_rdma_listen_comm {
688-
/* This base listen communicator must be the first member of
689-
* this struct. This allows casting between pointers of this
690-
* struct and its base struct. */
691-
nccl_net_ofi_listen_comm_t base;
687+
class nccl_net_ofi_rdma_listen_comm : public nccl_net_ofi_listen_comm {
688+
public:
689+
int accept(nccl_net_ofi_recv_comm_t **recv_comm) override;
690+
int close() override;
692691

693692
/* Associated listener from connection manager */
694693
nccl_ofi_cm_listener *listener;
@@ -698,7 +697,7 @@ typedef struct nccl_net_ofi_rdma_listen_comm {
698697

699698
/* Stage of connection establishment on listen side */
700699
nccl_ofi_comm_stage_t stage;
701-
} nccl_net_ofi_rdma_listen_comm_t;
700+
};
702701

703702

704703
class nccl_net_ofi_rdma_domain_rail_t {
@@ -1035,7 +1034,7 @@ class nccl_net_ofi_rdma_ep_t : public nccl_net_ofi_ep_t {
10351034
int cleanup_resources() override;
10361035

10371036
int listen(nccl_net_ofi_conn_handle_t *handle,
1038-
nccl_net_ofi_listen_comm_t **listen_comm) override;
1037+
nccl_net_ofi_listen_comm **listen_comm) override;
10391038

10401039
/**
10411040
* @brief Execute the connect functionality from listen/connect/accept

include/nccl_ofi_sendrecv.h

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -50,19 +50,19 @@ class nccl_net_ofi_sendrecv_mr_handle_t : public nccl_net_ofi_mr_handle_t {
5050
ofi_mr_ptr mr;
5151
};
5252

53-
typedef struct nccl_net_ofi_sendrecv_listen_comm {
54-
/* This base listen communicator must be the first member of
55-
* this struct. This allows casting between pointers of this
56-
* struct and its base struct. */
57-
nccl_net_ofi_listen_comm_t base;
53+
class nccl_net_ofi_sendrecv_listen_comm : public nccl_net_ofi_listen_comm {
54+
public:
55+
int accept(nccl_net_ofi_recv_comm_t **recv_comm) override;
56+
int close() override;
5857

5958
struct fid_ep *local_ep;
6059
fi_addr_t local_ep_addr;
6160
/* Saves temporary state when creating receive communicator object */
6261
save_comm_state_t state;
6362

6463
nccl_ofi_cm_listener *listener;
65-
} nccl_net_ofi_sendrecv_listen_comm_t;
64+
};
65+
6666

6767
typedef struct nccl_net_ofi_sendrecv_send_comm {
6868
/* This base send communicator must be the first member of this
@@ -183,7 +183,7 @@ class nccl_net_ofi_sendrecv_ep_t : public nccl_net_ofi_ep_t {
183183
nccl_net_ofi_sendrecv_ep_t(nccl_net_ofi_sendrecv_domain_t *domain_arg);
184184

185185
int listen(nccl_net_ofi_conn_handle_t *handle,
186-
nccl_net_ofi_listen_comm_t **listen_comm) override;
186+
nccl_net_ofi_listen_comm **listen_comm) override;
187187

188188
int connect(nccl_net_ofi_conn_handle_t *handle,
189189
nccl_net_ofi_send_comm_t **send_comm,

src/gin/nccl_ofi_gin.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@ int nccl_ofi_gin_listen_comm::connect(nccl_net_ofi_conn_handle_t *handles[], int
160160
}
161161
}
162162
if (r_comm == nullptr) {
163-
ret = l_comm->accept(l_comm, &r_comm);
163+
ret = l_comm->accept(&r_comm);
164164
if (ret != 0) {
165165
NCCL_OFI_WARN("Error in bootstrap ring accept: %d", ret);
166166
return ret;

src/gin/nccl_ofi_gin_api.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,7 @@ static ncclResult_t nccl_ofi_gin_listen(void *ctx, int dev, void *handle, void *
149149

150150
nccl_net_ofi_ep_t *ep = device->get_ep(0, static_cast<long>(comm_id));
151151

152-
nccl_net_ofi_listen_comm_t *l_comm = nullptr;
152+
nccl_net_ofi_listen_comm *l_comm = nullptr;
153153
int ret = ep->listen(static_cast<nccl_net_ofi_conn_handle_t *>(handle), &l_comm);
154154
if (ret != 0) {
155155
NCCL_OFI_WARN("GIN: error listening on device %i.", dev);

src/nccl_ofi_api.cpp

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -147,8 +147,8 @@ ncclResult_t nccl_net_ofi_listen(int dev_id, void *handle, void **lComm,
147147
int ret = 0;
148148
nccl_net_ofi_device_t *device = nullptr;
149149
nccl_net_ofi_ep_t *ep = nullptr;
150-
nccl_net_ofi_listen_comm_t **listen_comm =
151-
reinterpret_cast<nccl_net_ofi_listen_comm_t **>(lComm);
150+
nccl_net_ofi_listen_comm **listen_comm =
151+
reinterpret_cast<nccl_net_ofi_listen_comm **>(lComm);
152152

153153
/* Validate plugin */
154154
if (OFI_UNLIKELY(plugin == nullptr)) {
@@ -308,13 +308,13 @@ ncclResult_t nccl_net_ofi_accept(void *lComm, void **rComm)
308308
}
309309

310310
/* Invoke listen communicator accept() function */
311-
nccl_net_ofi_listen_comm_t *listen_comm =
312-
reinterpret_cast<nccl_net_ofi_listen_comm_t *>(lComm);
311+
nccl_net_ofi_listen_comm *listen_comm =
312+
reinterpret_cast<nccl_net_ofi_listen_comm *>(lComm);
313313
nccl_net_ofi_recv_comm_t **recv_comm =
314314
reinterpret_cast<nccl_net_ofi_recv_comm_t **>(rComm);
315315
int ret = 0;
316316
try {
317-
ret = listen_comm->accept(listen_comm, recv_comm);
317+
ret = listen_comm->accept(recv_comm);
318318
}
319319
catch (const std::exception &e) {
320320
NCCL_OFI_WARN("Caught exception in plugin accept: %s", e.what());
@@ -642,8 +642,8 @@ ncclResult_t nccl_net_ofi_closeRecv(void *rComm)
642642

643643
ncclResult_t nccl_net_ofi_closeListen(void *lComm)
644644
{
645-
nccl_net_ofi_listen_comm_t *listen_comm =
646-
(nccl_net_ofi_listen_comm_t *)lComm;
645+
nccl_net_ofi_listen_comm *listen_comm =
646+
(nccl_net_ofi_listen_comm *)lComm;
647647

648648
/* neuron has a cleanup race between the atexit handler and *
649649
* calling close on all the communicators, so be more silent
@@ -659,7 +659,7 @@ ncclResult_t nccl_net_ofi_closeListen(void *lComm)
659659
return check_return(ncclInternalError);
660660
}
661661

662-
int ret = listen_comm->close(listen_comm);
662+
int ret = listen_comm->close();
663663
return nccl_net_ofi_retval_translate_impl(ret);
664664
}
665665

0 commit comments

Comments
 (0)