@@ -78,6 +78,7 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
7878#define PML_UCX_REQ_ALLOCA () \
7979 ((char *)alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size);
8080
81+
8182static int mca_pml_ucx_send_worker_address (void )
8283{
8384 ucp_address_t * address ;
@@ -111,9 +112,10 @@ static int mca_pml_ucx_recv_worker_address(ompi_proc_t *proc,
111112
112113 * address_p = NULL ;
113114 OPAL_MODEX_RECV (ret , & mca_pml_ucx_component .pmlm_version , & proc -> super .proc_name ,
114- (void * * )address_p , addrlen_p );
115+ (void * * )address_p , addrlen_p );
115116 if (ret < 0 ) {
116- PML_UCX_ERROR ("Failed to receive EP address" );
117+ PML_UCX_ERROR ("Failed to receive UCX worker address: %s (%d)" ,
118+ opal_strerror (ret ), ret );
117119 }
118120 return ret ;
119121}
@@ -267,7 +269,7 @@ int mca_pml_ucx_cleanup(void)
267269 return OMPI_SUCCESS ;
268270}
269271
270- ucp_ep_h mca_pml_ucx_add_proc ( ompi_communicator_t * comm , int dst )
272+ static ucp_ep_h mca_pml_ucx_add_proc_common ( ompi_proc_t * proc )
271273{
272274 ucp_ep_params_t ep_params ;
273275 ucp_address_t * address ;
@@ -276,90 +278,91 @@ ucp_ep_h mca_pml_ucx_add_proc(ompi_communicator_t *comm, int dst)
276278 ucp_ep_h ep ;
277279 int ret ;
278280
279- ompi_proc_t * proc0 = ompi_comm_peer_lookup (comm , 0 );
280- ompi_proc_t * proc_peer = ompi_comm_peer_lookup (comm , dst );
281-
282- /* Note, mca_pml_base_pml_check_selected, doesn't use 3rd argument */
283- if (OMPI_SUCCESS != (ret = mca_pml_base_pml_check_selected ("ucx" ,
284- & proc0 ,
285- dst ))) {
286- return NULL ;
287- }
288-
289- ret = mca_pml_ucx_recv_worker_address (proc_peer , & address , & addrlen );
281+ ret = mca_pml_ucx_recv_worker_address (proc , & address , & addrlen );
290282 if (ret < 0 ) {
291- PML_UCX_ERROR ("Failed to receive worker address from proc: %d" , proc_peer -> super .proc_name .vpid );
292283 return NULL ;
293284 }
294285
295- PML_UCX_VERBOSE (2 , "connecting to proc. %d" , proc_peer -> super .proc_name .vpid );
286+ PML_UCX_VERBOSE (2 , "connecting to proc. %d" , proc -> super .proc_name .vpid );
296287
297288 ep_params .field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS ;
298289 ep_params .address = address ;
299290
300291 status = ucp_ep_create (ompi_pml_ucx .ucp_worker , & ep_params , & ep );
301292 free (address );
302293 if (UCS_OK != status ) {
303- PML_UCX_ERROR ("Failed to connect to proc: %d, %s" , proc_peer -> super .proc_name .vpid ,
304- ucs_status_string (status ));
294+ PML_UCX_ERROR ("ucp_ep_create(proc=%d) failed: %s" ,
295+ proc -> super .proc_name .vpid ,
296+ ucs_status_string (status ));
305297 return NULL ;
306298 }
307299
308- proc_peer -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_PML ] = ep ;
309-
300+ proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_PML ] = ep ;
310301 return ep ;
311302}
312303
304+ static ucp_ep_h mca_pml_ucx_add_proc (ompi_communicator_t * comm , int dst )
305+ {
306+ ompi_proc_t * proc0 = ompi_comm_peer_lookup (comm , 0 );
307+ ompi_proc_t * proc_peer = ompi_comm_peer_lookup (comm , dst );
308+ int ret ;
309+
310+ /* Note, mca_pml_base_pml_check_selected, doesn't use 3rd argument */
311+ if (OMPI_SUCCESS != (ret = mca_pml_base_pml_check_selected ("ucx" ,
312+ & proc0 ,
313+ dst ))) {
314+ return NULL ;
315+ }
316+
317+ return mca_pml_ucx_add_proc_common (proc_peer );
318+ }
319+
313320int mca_pml_ucx_add_procs (struct ompi_proc_t * * procs , size_t nprocs )
314321{
315- ucp_ep_params_t ep_params ;
316- ucp_address_t * address ;
317- ucs_status_t status ;
318322 ompi_proc_t * proc ;
319- size_t addrlen ;
320323 ucp_ep_h ep ;
321324 size_t i ;
322325 int ret ;
323326
324327 if (OMPI_SUCCESS != (ret = mca_pml_base_pml_check_selected ("ucx" ,
325- procs ,
326- nprocs ))) {
328+ procs ,
329+ nprocs ))) {
327330 return ret ;
328331 }
329332
330333 for (i = 0 ; i < nprocs ; ++ i ) {
331334 proc = procs [(i + OMPI_PROC_MY_NAME -> vpid ) % nprocs ];
332-
333- ret = mca_pml_ucx_recv_worker_address (proc , & address , & addrlen );
334- if (ret < 0 ) {
335- PML_UCX_ERROR ("Failed to receive worker address from proc: %d" ,
336- proc -> super .proc_name .vpid );
337- return ret ;
338- }
339-
340- if (proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_PML ]) {
341- PML_UCX_VERBOSE (3 , "already connected to proc. %d" , proc -> super .proc_name .vpid );
342- continue ;
335+ ep = mca_pml_ucx_add_proc_common (proc );
336+ if (ep == NULL ) {
337+ return OMPI_ERROR ;
343338 }
339+ }
344340
345- PML_UCX_VERBOSE (2 , "connecting to proc. %d" , proc -> super .proc_name .vpid );
341+ return OMPI_SUCCESS ;
342+ }
346343
347- ep_params .field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS ;
348- ep_params .address = address ;
344+ static inline ucp_ep_h mca_pml_ucx_get_ep (ompi_communicator_t * comm , int rank )
345+ {
346+ ucp_ep_h ep ;
349347
350- status = ucp_ep_create (ompi_pml_ucx .ucp_worker , & ep_params , & ep );
351- free (address );
348+ ep = ompi_comm_peer_lookup (comm , rank )-> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_PML ];
349+ if (OPAL_LIKELY (ep != NULL )) {
350+ return ep ;
351+ }
352352
353- if (UCS_OK != status ) {
354- PML_UCX_ERROR ("Failed to connect to proc: %d, %s" , proc -> super .proc_name .vpid ,
355- ucs_status_string (status ));
356- return OMPI_ERROR ;
357- }
353+ ep = mca_pml_ucx_add_proc (comm , rank );
354+ if (OPAL_LIKELY (ep != NULL )) {
355+ return ep ;
356+ }
358357
359- proc -> proc_endpoints [OMPI_PROC_ENDPOINT_TAG_PML ] = ep ;
358+ if (rank >= ompi_comm_size (comm )) {
359+ PML_UCX_ERROR ("Rank number (%d) is larger than communicator size (%d)" ,
360+ rank , ompi_comm_size (comm ));
361+ } else {
362+ PML_UCX_ERROR ("Failed to resolve UCX endpoint for rank %d" , rank );
360363 }
361364
362- return OMPI_SUCCESS ;
365+ return NULL ;
363366}
364367
365368static void mca_pml_ucx_waitall (void * * reqs , size_t * count_p )
@@ -581,7 +584,6 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
581584
582585 ep = mca_pml_ucx_get_ep (comm , dst );
583586 if (OPAL_UNLIKELY (NULL == ep )) {
584- PML_UCX_ERROR ("Failed to get ep for rank %d" , dst );
585587 return OMPI_ERROR ;
586588 }
587589
@@ -695,7 +697,6 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
695697
696698 ep = mca_pml_ucx_get_ep (comm , dst );
697699 if (OPAL_UNLIKELY (NULL == ep )) {
698- PML_UCX_ERROR ("Failed to get ep for rank %d" , dst );
699700 return OMPI_ERROR ;
700701 }
701702
@@ -779,7 +780,6 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
779780
780781 ep = mca_pml_ucx_get_ep (comm , dst );
781782 if (OPAL_UNLIKELY (NULL == ep )) {
782- PML_UCX_ERROR ("Failed to get ep for rank %d" , dst );
783783 return OMPI_ERROR ;
784784 }
785785
0 commit comments