@@ -70,8 +70,8 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
         1ul << (PML_UCX_TAG_BITS - 1),
         1ul << (PML_UCX_CONTEXT_BITS),
     },
-    NULL,
-    NULL
+    NULL,   /* ucp_context */
+    NULL    /* ucp_worker */
 };

 static int mca_pml_ucx_send_worker_address(void)
@@ -116,6 +116,7 @@ static int mca_pml_ucx_recv_worker_address(ompi_proc_t *proc,

 int mca_pml_ucx_open(void)
 {
+    ucp_context_attr_t attr;
     ucp_params_t params;
     ucp_config_t *config;
     ucs_status_t status;
@@ -128,10 +129,17 @@ int mca_pml_ucx_open(void)
         return OMPI_ERROR;
     }

+    /* Initialize UCX context */
+    params.field_mask = UCP_PARAM_FIELD_FEATURES |
+                        UCP_PARAM_FIELD_REQUEST_SIZE |
+                        UCP_PARAM_FIELD_REQUEST_INIT |
+                        UCP_PARAM_FIELD_REQUEST_CLEANUP |
+                        UCP_PARAM_FIELD_TAG_SENDER_MASK;
     params.features = UCP_FEATURE_TAG;
     params.request_size = sizeof(ompi_request_t);
     params.request_init = mca_pml_ucx_request_init;
     params.request_cleanup = mca_pml_ucx_request_cleanup;
+    params.tag_sender_mask = PML_UCX_SPECIFIC_SOURCE_MASK;

     status = ucp_init(&params, config, &ompi_pml_ucx.ucp_context);
     ucp_config_release(config);
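For context (not part of the patch): tag_sender_mask tells UCX which bits of the 64-bit UCP tag uniquely identify the sender, so receives posted for a specific source can be matched more efficiently. A minimal sketch of such a packed-tag layout, with made-up bit widths and helper names (the real layout is defined by the PML_UCX_MAKE_RECV_TAG / PML_UCX_TAG_BITS macros elsewhere in this component):

    #include <stdint.h>

    /* Hypothetical layout: [ user tag | source rank | communicator id ] */
    #define SKETCH_CID_BITS   16
    #define SKETCH_RANK_BITS  24

    static uint64_t sketch_make_tag(uint64_t utag, uint64_t rank, uint64_t cid)
    {
        return (utag << (SKETCH_RANK_BITS + SKETCH_CID_BITS)) |
               (rank << SKETCH_CID_BITS) |
               cid;
    }

    /* The bits that encode the sender; this is what tag_sender_mask names. */
    static const uint64_t sketch_sender_mask =
        ((UINT64_C(1) << SKETCH_RANK_BITS) - 1) << SKETCH_CID_BITS;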
@@ -140,6 +148,17 @@ int mca_pml_ucx_open(void)
         return OMPI_ERROR;
     }

+    /* Query UCX attributes */
+    attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
+    status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
+    if (UCS_OK != status) {
+        ucp_cleanup(ompi_pml_ucx.ucp_context);
+        ompi_pml_ucx.ucp_context = NULL;
+        return OMPI_ERROR;
+    }
+
+    ompi_pml_ucx.request_size = attr.request_size;
+
     return OMPI_SUCCESS;
 }

@@ -163,7 +182,7 @@ int mca_pml_ucx_init(void)

     /* TODO check MPI thread mode */
     status = ucp_worker_create(ompi_pml_ucx.ucp_context, UCS_THREAD_MODE_SINGLE,
-                              &ompi_pml_ucx.ucp_worker);
+                               &ompi_pml_ucx.ucp_worker);
     if (UCS_OK != status) {
         return OMPI_ERROR;
     }
@@ -252,6 +271,7 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
     ucp_address_t *address;
     ucs_status_t status;
+    ompi_proc_t *proc;
     size_t addrlen;
     ucp_ep_h ep;
     size_t i;
@@ -264,47 +284,109 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
     }

     for (i = 0; i < nprocs; ++i) {
-        ret = mca_pml_ucx_recv_worker_address(procs[i], &address, &addrlen);
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+
+        ret = mca_pml_ucx_recv_worker_address(proc, &address, &addrlen);
         if (ret < 0) {
-            PML_UCX_ERROR("Failed to receive worker address from proc: %d", procs[i]->super.proc_name.vpid);
+            PML_UCX_ERROR("Failed to receive worker address from proc: %d",
+                          proc->super.proc_name.vpid);
             return ret;
         }

-        if (procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
-            PML_UCX_VERBOSE(3, "already connected to proc. %d", procs[i]->super.proc_name.vpid);
+        if (proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
+            PML_UCX_VERBOSE(3, "already connected to proc. %d", proc->super.proc_name.vpid);
             continue;
         }

-        PML_UCX_VERBOSE(2, "connecting to proc. %d", procs[i]->super.proc_name.vpid);
+        PML_UCX_VERBOSE(2, "connecting to proc. %d", proc->super.proc_name.vpid);
         status = ucp_ep_create(ompi_pml_ucx.ucp_worker, address, &ep);
         free(address);

         if (UCS_OK != status) {
-            PML_UCX_ERROR("Failed to connect to proc: %d, %s", procs[i]->super.proc_name.vpid,
-                          ucs_status_string(status));
+            PML_UCX_ERROR("Failed to connect to proc: %d, %s", proc->super.proc_name.vpid,
+                          ucs_status_string(status));
             return OMPI_ERROR;
         }

-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
     }

     return OMPI_SUCCESS;
 }

+static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
+{
+    ucs_status_t status;
+    size_t i;
+
+    PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", *count_p);
+    for (i = 0; i < *count_p; ++i) {
+        do {
+            opal_progress();
+            status = ucp_request_test(reqs[i], NULL);
+        } while (status == UCS_INPROGRESS);
+        if (status != UCS_OK) {
+            PML_UCX_ERROR("disconnect request failed: %s",
+                          ucs_status_string(status));
+        }
+        ucp_request_release(reqs[i]);
+        reqs[i] = NULL;
+    }
+
+    *count_p = 0;
+}
+
 int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
+    ompi_proc_t *proc;
+    size_t num_reqs, max_reqs;
+    void *dreq, **dreqs;
     ucp_ep_h ep;
     size_t i;

+    max_reqs = ompi_pml_ucx.num_disconnect;
+    if (max_reqs > nprocs) {
+        max_reqs = nprocs;
+    }
+
+    dreqs = malloc(sizeof(*dreqs) * max_reqs);
+    if (dreqs == NULL) {
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+
+    num_reqs = 0;
+
     for (i = 0; i < nprocs; ++i) {
-        PML_UCX_VERBOSE(2, "disconnecting from rank %d", procs[i]->super.proc_name.vpid);
-        ep = procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
-        if (ep != NULL) {
-            ucp_ep_destroy(ep);
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+        ep = proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
+        if (ep == NULL) {
+            continue;
+        }
+
+        PML_UCX_VERBOSE(2, "disconnecting from rank %d", proc->super.proc_name.vpid);
+        dreq = ucp_disconnect_nb(ep);
+        if (dreq != NULL) {
+            if (UCS_PTR_IS_ERR(dreq)) {
+                PML_UCX_ERROR("ucp_disconnect_nb(%d) failed: %s",
+                              proc->super.proc_name.vpid,
+                              ucs_status_string(UCS_PTR_STATUS(dreq)));
+            } else {
+                dreqs[num_reqs++] = dreq;
+            }
+        }
+
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
+
+        if (num_reqs >= ompi_pml_ucx.num_disconnect) {
+            mca_pml_ucx_waitall(dreqs, &num_reqs);
         }
-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
     }
+
+    mca_pml_ucx_waitall(dreqs, &num_reqs);
+    free(dreqs);
+
     opal_pmix.fence(NULL, 0);
+
     return OMPI_SUCCESS;
 }

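A side note on the (i + OMPI_PROC_MY_NAME->vpid) % nprocs indexing above: each rank starts its connect and disconnect loops at its own vpid and wraps around, so peers are contacted in a staggered order instead of every rank hitting rank 0 first; likewise, ompi_pml_ucx.num_disconnect caps how many ucp_disconnect_nb() requests stay in flight before mca_pml_ucx_waitall() drains them. A small stand-alone sketch (illustrative only, assuming 4 ranks) that just prints the visiting order:

    #include <stdio.h>

    int main(void)
    {
        int nprocs = 4;

        /* Rank r visits peers r, r+1, ..., nprocs-1, 0, ..., r-1 */
        for (int rank = 0; rank < nprocs; ++rank) {
            printf("rank %d connects in order:", rank);
            for (int i = 0; i < nprocs; ++i) {
                printf(" %d", (i + rank) % nprocs);
            }
            printf("\n");
        }
        return 0;
    }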
@@ -321,14 +403,7 @@ int mca_pml_ucx_enable(bool enable)

 int mca_pml_ucx_progress(void)
 {
-    static int inprogress = 0;
-    if (inprogress != 0) {
-        return 0;
-    }
-
-    ++inprogress;
     ucp_worker_progress(ompi_pml_ucx.ucp_worker);
-    --inprogress;
     return OMPI_SUCCESS;
 }

@@ -393,52 +468,32 @@ int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
     return OMPI_SUCCESS;
 }

-static void
-mca_pml_ucx_blocking_recv_completion(void *request, ucs_status_t status,
-                                     ucp_tag_recv_info_t *info)
-{
-    ompi_request_t *req = request;
-
-    PML_UCX_VERBOSE(8, "blocking receive request %p completed with status %s tag %"PRIx64" len %zu",
-                    (void*)req, ucs_status_string(status), info->sender_tag,
-                    info->length);
-
-    mca_pml_ucx_set_recv_status(&req->req_status, status, info);
-    PML_UCX_ASSERT(!(REQUEST_COMPLETE(req)));
-    ompi_request_complete(req, true);
-}
-
 int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src,
                      int tag, struct ompi_communicator_t *comm,
                      ompi_status_public_t *mpi_status)
 {
     ucp_tag_t ucp_tag, ucp_tag_mask;
-    ompi_request_t *req;
+    ucp_tag_recv_info_t info;
+    ucs_status_t status;
+    void *req;

     PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");

     PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
-    req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
-                                           mca_pml_ucx_get_datatype(datatype),
-                                           ucp_tag, ucp_tag_mask,
-                                           mca_pml_ucx_blocking_recv_completion);
-    if (UCS_PTR_IS_ERR(req)) {
-        PML_UCX_ERROR("ucx recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
-        return OMPI_ERROR;
-    }
+    req = alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
+    status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
+                              mca_pml_ucx_get_datatype(datatype),
+                              ucp_tag, ucp_tag_mask, req);

     ucp_worker_progress(ompi_pml_ucx.ucp_worker);
-    while (!REQUEST_COMPLETE(req)) {
+    for (;;) {
+        status = ucp_request_test(req, &info);
+        if (status != UCS_INPROGRESS) {
+            mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info);
+            return OMPI_SUCCESS;
+        }
         opal_progress();
     }
-
-    if (mpi_status != MPI_STATUS_IGNORE) {
-        *mpi_status = req->req_status;
-    }
-
-    req->req_complete = REQUEST_PENDING;
-    ucp_request_release(req);
-    return OMPI_SUCCESS;
 }

 static inline const char *mca_pml_ucx_send_mode_name(mca_pml_base_send_mode_t mode)
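A note on the new blocking receive path (not part of the patch): ucp_tag_recv_nbr() takes caller-provided request storage, and UCX expects at least ompi_pml_ucx.request_size bytes (the value queried in mca_pml_ucx_open()) to be available before the pointer it is given, which is why the code passes alloca(request_size) + request_size. A heap-based sketch of the same convention, using hypothetical helper names:

    /* Sketch only: allocate external storage for one UCP request and
     * position the handle just past the reserved area, mirroring the
     * alloca() trick used in mca_pml_ucx_recv(). */
    #include <stdlib.h>

    static void *sketch_alloc_req(size_t ucp_request_size, void **storage_p)
    {
        void *storage = malloc(ucp_request_size);    /* reserved space      */
        if (storage == NULL) {
            return NULL;
        }
        *storage_p = storage;                        /* caller frees this   */
        return (char *)storage + ucp_request_size;   /* passed as 'req'     */
    }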
@@ -583,6 +638,7 @@ int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
     } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
@@ -628,7 +684,8 @@ int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
         PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
-    } else if (UCS_PTR_STATUS(ucp_msg) == UCS_ERR_NO_MESSAGE) {
+    } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
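The opal_progress() calls added to the unmatched branches matter for callers that poll: a probe that finds nothing does not by itself drive the UCX worker, so without progressing here a polling application might never observe the incoming message. An illustrative MPI fragment (assuming this PML is selected and a matching send eventually arrives from another rank):

    #include <mpi.h>

    /* Poll until a message from 'src' with 'tag' becomes visible. */
    static void sketch_wait_for_message(int src, int tag, MPI_Comm comm)
    {
        int flag = 0;
        while (!flag) {
            /* Each MPI_Iprobe maps onto mca_pml_ucx_iprobe(); the
             * opal_progress() added above keeps communication moving
             * even though this loop never posts a receive. */
            MPI_Iprobe(src, tag, comm, &flag, MPI_STATUS_IGNORE);
        }
    }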