@@ -207,6 +207,42 @@ int mca_pml_ucx_cleanup(void)
207207 return OMPI_SUCCESS ;
208208}
209209
210+ ucp_ep_h mca_pml_ucx_add_proc (ompi_communicator_t * comm , int dst )
211+ {
212+ ucp_address_t * address ;
213+ ucs_status_t status ;
214+ size_t addrlen ;
215+ ucp_ep_h ep ;
216+ int ret ;
217+
218+ ompi_proc_t * proc = ompi_comm_peer_lookup (comm , 0 );
219+ ompi_proc_t * proc_peer = ompi_comm_peer_lookup (comm , dst );
220+
221+ /* Note, mca_pml_base_pml_check_selected, doesn't use 3rd argument */
222+ if (OMPI_SUCCESS != (ret = mca_pml_base_pml_check_selected ("ucx" ,
223+ & proc ,
224+ dst ))) {
225+ return NULL ;
226+ }
227+
228+ ret = mca_pml_ucx_recv_worker_address (proc_peer , & address , & addrlen );
229+ if (ret < 0 ) {
230+ return NULL ;
231+ }
232+
233+ PML_UCX_VERBOSE (2 , "connecting to proc. %d" , proc_peer -> super .proc_name .vpid );
234+ status = ucp_ep_create (ompi_pml_ucx .ucp_worker , address , & ep );
235+ free (address );
236+ if (UCS_OK != status ) {
237+ PML_UCX_ERROR ("Failed to connect" );
238+ return NULL ;
239+ }
240+
241+ proc_peer -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_PML ] = ep ;
242+
243+ return ep ;
244+ }
245+
210246int mca_pml_ucx_add_procs (struct ompi_proc_t * * procs , size_t nprocs )
211247{
212248 ucp_address_t * address ;
@@ -426,7 +462,7 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
426462 struct ompi_request_t * * request )
427463{
428464 mca_pml_ucx_persistent_request_t * req ;
429-
465+ ucp_ep_h ep ;
430466
431467 req = (mca_pml_ucx_persistent_request_t * )PML_UCX_FREELIST_GET (& ompi_pml_ucx .persistent_reqs );
432468 if (req == NULL ) {
@@ -436,14 +472,20 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
436472 PML_UCX_TRACE_SEND ("isend_init request *%p=%p" , buf , count , datatype , dst ,
437473 tag , mode , comm , (void * )request , (void * )req )
438474
475+ ep = mca_pml_ucx_get_ep (comm , dst );
476+ if (OPAL_UNLIKELY (NULL == ep )) {
477+ PML_UCX_ERROR ("Failed to get ep" );
478+ return OMPI_ERROR ;
479+ }
480+
439481 req -> ompi .req_state = OMPI_REQUEST_INACTIVE ;
440482 req -> flags = MCA_PML_UCX_REQUEST_FLAG_SEND ;
441483 req -> buffer = (void * )buf ;
442484 req -> count = count ;
443485 req -> datatype = mca_pml_ucx_get_datatype (datatype );
444486 req -> tag = PML_UCX_MAKE_SEND_TAG (tag , comm );
445487 req -> send .mode = mode ;
446- req -> send .ep = mca_pml_ucx_get_ep ( comm , dst ) ;
488+ req -> send .ep = ep ;
447489
448490 * request = & req -> ompi ;
449491 return OMPI_SUCCESS ;
@@ -455,13 +497,20 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
455497 struct ompi_request_t * * request )
456498{
457499 ompi_request_t * req ;
500+ ucp_ep_h ep ;
458501
459502 PML_UCX_TRACE_SEND ("isend request *%p" , buf , count , datatype , dst , tag , mode ,
460503 comm , (void * )request )
461504
462505 /* TODO special care to sync/buffered send */
463506
464- req = (ompi_request_t * )ucp_tag_send_nb (mca_pml_ucx_get_ep (comm , dst ), buf , count ,
507+ ep = mca_pml_ucx_get_ep (comm , dst );
508+ if (OPAL_UNLIKELY (NULL == ep )) {
509+ PML_UCX_ERROR ("Failed to get ep" );
510+ return OMPI_ERROR ;
511+ }
512+
513+ req = (ompi_request_t * )ucp_tag_send_nb (ep , buf , count ,
465514 mca_pml_ucx_get_datatype (datatype ),
466515 PML_UCX_MAKE_SEND_TAG (tag , comm ),
467516 mca_pml_ucx_send_completion );
@@ -484,12 +533,19 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
484533 struct ompi_communicator_t * comm )
485534{
486535 ompi_request_t * req ;
536+ ucp_ep_h ep ;
487537
488538 PML_UCX_TRACE_SEND ("%s" , buf , count , datatype , dst , tag , mode , comm , "send" );
489539
490540 /* TODO special care to sync/buffered send */
491541
492- req = (ompi_request_t * )ucp_tag_send_nb (mca_pml_ucx_get_ep (comm , dst ), buf , count ,
542+ ep = mca_pml_ucx_get_ep (comm , dst );
543+ if (OPAL_UNLIKELY (NULL == ep )) {
544+ PML_UCX_ERROR ("Failed to get ep" );
545+ return OMPI_ERROR ;
546+ }
547+
548+ req = (ompi_request_t * )ucp_tag_send_nb (ep , buf , count ,
493549 mca_pml_ucx_get_datatype (datatype ),
494550 PML_UCX_MAKE_SEND_TAG (tag , comm ),
495551 mca_pml_ucx_send_completion );
0 commit comments