@@ -70,8 +70,8 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
         1ul << (PML_UCX_TAG_BITS - 1),
         1ul << (PML_UCX_CONTEXT_BITS),
     },
-    NULL,
-    NULL
+    NULL,    /* ucp_context */
+    NULL     /* ucp_worker */
 };
 
 static int mca_pml_ucx_send_worker_address(void)
@@ -116,6 +116,7 @@ static int mca_pml_ucx_recv_worker_address(ompi_proc_t *proc,
 
 int mca_pml_ucx_open(void)
 {
+    ucp_context_attr_t attr;
     ucp_params_t params;
     ucp_config_t *config;
     ucs_status_t status;
@@ -128,10 +129,17 @@ int mca_pml_ucx_open(void)
         return OMPI_ERROR;
     }
 
+    /* Initialize UCX context */
+    params.field_mask       = UCP_PARAM_FIELD_FEATURES |
+                              UCP_PARAM_FIELD_REQUEST_SIZE |
+                              UCP_PARAM_FIELD_REQUEST_INIT |
+                              UCP_PARAM_FIELD_REQUEST_CLEANUP |
+                              UCP_PARAM_FIELD_TAG_SENDER_MASK;
     params.features         = UCP_FEATURE_TAG;
     params.request_size     = sizeof(ompi_request_t);
     params.request_init     = mca_pml_ucx_request_init;
     params.request_cleanup  = mca_pml_ucx_request_cleanup;
+    params.tag_sender_mask  = PML_UCX_SPECIFIC_SOURCE_MASK;
 
     status = ucp_init(&params, config, &ompi_pml_ucx.ucp_context);
     ucp_config_release(config);
@@ -140,6 +148,17 @@ int mca_pml_ucx_open(void)
         return OMPI_ERROR;
     }
 
+    /* Query UCX attributes */
+    attr.field_mask         = UCP_ATTR_FIELD_REQUEST_SIZE;
+    status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
+    if (UCS_OK != status) {
+        ucp_cleanup(ompi_pml_ucx.ucp_context);
+        ompi_pml_ucx.ucp_context = NULL;
+        return OMPI_ERROR;
+    }
+
+    ompi_pml_ucx.request_size = attr.request_size;
+
     return OMPI_SUCCESS;
 }
 
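Aside (editorial, not part of the patch): the reason the queried request_size is saved is UCX's convention for the "nbr" calls used further down. ucp_tag_recv_nbr() takes a caller-provided request handle and expects at least request_size bytes of usable scratch space immediately before that pointer, where request_size comes from ucp_context_query(). A minimal sketch of that convention, assuming only the standard UCP API; the helper name is made up and the storage must stay alive until completion, so the sketch polls to completion before returning:

#include <alloca.h>
#include <ucp/api/ucp.h>

/* Blocking receive sketch: request storage comes from alloca(), so we keep
 * polling until the request completes before the stack frame goes away.
 * 'request_size' is the value obtained from ucp_context_query(). */
static ucs_status_t blocking_recv_sketch(ucp_worker_h worker, void *buf,
                                         size_t count, ucp_datatype_t dt,
                                         ucp_tag_t tag, ucp_tag_t tag_mask,
                                         size_t request_size)
{
    /* ucp_tag_recv_nbr() needs request_size bytes of headroom *before*
     * the request pointer it is handed. */
    void *req = (char *)alloca(request_size) + request_size;
    ucp_tag_recv_info_t info;
    ucs_status_t status;

    status = ucp_tag_recv_nbr(worker, buf, count, dt, tag, tag_mask, req);
    if (UCS_STATUS_IS_ERR(status)) {
        return status;
    }

    do {
        ucp_worker_progress(worker);
        status = ucp_request_test(req, &info);
    } while (status == UCS_INPROGRESS);

    return status;
}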
@@ -163,7 +182,7 @@ int mca_pml_ucx_init(void)
 
     /* TODO check MPI thread mode */
     status = ucp_worker_create(ompi_pml_ucx.ucp_context, UCS_THREAD_MODE_SINGLE,
-                              &ompi_pml_ucx.ucp_worker);
+                               &ompi_pml_ucx.ucp_worker);
     if (UCS_OK != status) {
         return OMPI_ERROR;
     }
@@ -252,6 +271,7 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
     ucp_address_t *address;
     ucs_status_t status;
+    ompi_proc_t *proc;
     size_t addrlen;
     ucp_ep_h ep;
     size_t i;
@@ -264,47 +284,109 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
     }
 
     for (i = 0; i < nprocs; ++i) {
-        ret = mca_pml_ucx_recv_worker_address(procs[i], &address, &addrlen);
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+
+        ret = mca_pml_ucx_recv_worker_address(proc, &address, &addrlen);
         if (ret < 0) {
-            PML_UCX_ERROR("Failed to receive worker address from proc: %d", procs[i]->super.proc_name.vpid);
+            PML_UCX_ERROR("Failed to receive worker address from proc: %d",
+                          proc->super.proc_name.vpid);
             return ret;
         }
 
-        if (procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
-            PML_UCX_VERBOSE(3, "already connected to proc. %d", procs[i]->super.proc_name.vpid);
+        if (proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
+            PML_UCX_VERBOSE(3, "already connected to proc. %d", proc->super.proc_name.vpid);
             continue;
         }
 
-        PML_UCX_VERBOSE(2, "connecting to proc. %d", procs[i]->super.proc_name.vpid);
+        PML_UCX_VERBOSE(2, "connecting to proc. %d", proc->super.proc_name.vpid);
         status = ucp_ep_create(ompi_pml_ucx.ucp_worker, address, &ep);
         free(address);
 
         if (UCS_OK != status) {
-            PML_UCX_ERROR("Failed to connect to proc: %d, %s", procs[i]->super.proc_name.vpid,
-                                                               ucs_status_string(status));
+            PML_UCX_ERROR("Failed to connect to proc: %d, %s", proc->super.proc_name.vpid,
+                          ucs_status_string(status));
             return OMPI_ERROR;
         }
 
-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
     }
 
     return OMPI_SUCCESS;
 }
 
+static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
+{
+    ucs_status_t status;
+    size_t i;
+
+    PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", *count_p);
+    for (i = 0; i < *count_p; ++i) {
+        do {
+            opal_progress();
+            status = ucp_request_test(reqs[i], NULL);
+        } while (status == UCS_INPROGRESS);
+        if (status != UCS_OK) {
+            PML_UCX_ERROR("disconnect request failed: %s",
+                          ucs_status_string(status));
+        }
+        ucp_request_release(reqs[i]);
+        reqs[i] = NULL;
+    }
+
+    *count_p = 0;
+}
+
 int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
+    ompi_proc_t *proc;
+    size_t num_reqs, max_reqs;
+    void *dreq, **dreqs;
     ucp_ep_h ep;
     size_t i;
 
+    max_reqs = ompi_pml_ucx.num_disconnect;
+    if (max_reqs > nprocs) {
+        max_reqs = nprocs;
+    }
+
+    dreqs = malloc(sizeof(*dreqs) * max_reqs);
+    if (dreqs == NULL) {
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+
+    num_reqs = 0;
+
     for (i = 0; i < nprocs; ++i) {
-        PML_UCX_VERBOSE(2, "disconnecting from rank %d", procs[i]->super.proc_name.vpid);
-        ep = procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
-        if (ep != NULL) {
-            ucp_ep_destroy(ep);
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+        ep   = proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
+        if (ep == NULL) {
+            continue;
+        }
+
+        PML_UCX_VERBOSE(2, "disconnecting from rank %d", proc->super.proc_name.vpid);
+        dreq = ucp_disconnect_nb(ep);
+        if (dreq != NULL) {
+            if (UCS_PTR_IS_ERR(dreq)) {
+                PML_UCX_ERROR("ucp_disconnect_nb(%d) failed: %s",
+                              proc->super.proc_name.vpid,
+                              ucs_status_string(UCS_PTR_STATUS(dreq)));
+            } else {
+                dreqs[num_reqs++] = dreq;
+            }
+        }
+
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
+
+        if (num_reqs >= ompi_pml_ucx.num_disconnect) {
+            mca_pml_ucx_waitall(dreqs, &num_reqs);
         }
-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
     }
+
+    mca_pml_ucx_waitall(dreqs, &num_reqs);
+    free(dreqs);
+
     opal_pmix.fence(NULL, 0);
+
     return OMPI_SUCCESS;
 }
 
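Aside (editorial, not part of the patch): the del_procs changes above follow the usual UCX pattern for a non-blocking operation. ucp_disconnect_nb() can return NULL (already complete), an error pointer, or a request handle that must be progressed, tested, and released. A self-contained sketch of that pattern under those assumptions; the helper name is invented and ucp_worker_progress() stands in for OMPI's opal_progress():

#include <ucp/api/ucp.h>

/* Complete one request returned by a UCX non-blocking call such as
 * ucp_disconnect_nb(). */
static ucs_status_t wait_request_sketch(ucp_worker_h worker, void *request)
{
    ucs_status_t status;

    if (request == NULL) {
        return UCS_OK;                      /* completed immediately */
    }
    if (UCS_PTR_IS_ERR(request)) {
        return UCS_PTR_STATUS(request);     /* failed up front */
    }

    do {
        ucp_worker_progress(worker);        /* drive communication forward */
        status = ucp_request_test(request, NULL);
    } while (status == UCS_INPROGRESS);

    ucp_request_release(request);
    return status;
}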
@@ -321,14 +403,7 @@ int mca_pml_ucx_enable(bool enable)
 
 int mca_pml_ucx_progress(void)
 {
-    static int inprogress = 0;
-    if (inprogress != 0) {
-        return 0;
-    }
-
-    ++inprogress;
     ucp_worker_progress(ompi_pml_ucx.ucp_worker);
-    --inprogress;
     return OMPI_SUCCESS;
 }
 
@@ -393,52 +468,32 @@ int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
     return OMPI_SUCCESS;
 }
 
-static void
-mca_pml_ucx_blocking_recv_completion(void *request, ucs_status_t status,
-                                     ucp_tag_recv_info_t *info)
-{
-    ompi_request_t *req = request;
-
-    PML_UCX_VERBOSE(8, "blocking receive request %p completed with status %s tag %"PRIx64" len %zu",
-                    (void*)req, ucs_status_string(status), info->sender_tag,
-                    info->length);
-
-    mca_pml_ucx_set_recv_status(&req->req_status, status, info);
-    PML_UCX_ASSERT( !(REQUEST_COMPLETE(req)));
-    ompi_request_complete(req,true);
-}
-
 int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src,
                      int tag, struct ompi_communicator_t* comm,
                      ompi_status_public_t* mpi_status)
 {
     ucp_tag_t ucp_tag, ucp_tag_mask;
-    ompi_request_t *req;
+    ucp_tag_recv_info_t info;
+    ucs_status_t status;
+    void *req;
 
     PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");
 
     PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
-    req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
-                                           mca_pml_ucx_get_datatype(datatype),
-                                           ucp_tag, ucp_tag_mask,
-                                           mca_pml_ucx_blocking_recv_completion);
-    if (UCS_PTR_IS_ERR(req)) {
-        PML_UCX_ERROR("ucx recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
-        return OMPI_ERROR;
-    }
+    req = alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
+    status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
+                              mca_pml_ucx_get_datatype(datatype),
+                              ucp_tag, ucp_tag_mask, req);
 
     ucp_worker_progress(ompi_pml_ucx.ucp_worker);
-    while ( !REQUEST_COMPLETE(req) ) {
+    for (;;) {
+        status = ucp_request_test(req, &info);
+        if (status != UCS_INPROGRESS) {
+            mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info);
+            return OMPI_SUCCESS;
+        }
         opal_progress();
     }
-
-    if (mpi_status != MPI_STATUS_IGNORE) {
-        *mpi_status = req->req_status;
-    }
-
-    req->req_complete = REQUEST_PENDING;
-    ucp_request_release(req);
-    return OMPI_SUCCESS;
 }
 
 static inline const char *mca_pml_ucx_send_mode_name(mca_pml_base_send_mode_t mode)
@@ -583,6 +638,7 @@ int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
     } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
@@ -628,7 +684,8 @@ int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
         PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
         *matched         = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
-    } else if (UCS_PTR_STATUS(ucp_msg) == UCS_ERR_NO_MESSAGE) {
+    } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;