@@ -70,8 +70,8 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
         1ul << (PML_UCX_TAG_BITS - 1),
         1ul << (PML_UCX_CONTEXT_BITS),
     },
-    NULL,
-    NULL
+    NULL,   /* ucp_context */
+    NULL    /* ucp_worker */
 };
 
 static int mca_pml_ucx_send_worker_address(void)
@@ -116,6 +116,7 @@ static int mca_pml_ucx_recv_worker_address(ompi_proc_t *proc,
 
 int mca_pml_ucx_open(void)
 {
+    ucp_context_attr_t attr;
     ucp_params_t params;
     ucp_config_t *config;
     ucs_status_t status;
@@ -128,10 +129,17 @@ int mca_pml_ucx_open(void)
         return OMPI_ERROR;
     }
 
+    /* Initialize UCX context */
+    params.field_mask      = UCP_PARAM_FIELD_FEATURES |
+                             UCP_PARAM_FIELD_REQUEST_SIZE |
+                             UCP_PARAM_FIELD_REQUEST_INIT |
+                             UCP_PARAM_FIELD_REQUEST_CLEANUP |
+                             UCP_PARAM_FIELD_TAG_SENDER_MASK;
     params.features        = UCP_FEATURE_TAG;
     params.request_size    = sizeof(ompi_request_t);
     params.request_init    = mca_pml_ucx_request_init;
     params.request_cleanup = mca_pml_ucx_request_cleanup;
+    params.tag_sender_mask = PML_UCX_SPECIFIC_SOURCE_MASK;
 
     status = ucp_init(&params, config, &ompi_pml_ucx.ucp_context);
     ucp_config_release(config);
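
Note on the new field_mask assignment: ucp_params_t is an extensible struct, and
ucp_init() reads only the fields whose bits are set in field_mask, so every field
the caller fills in must have its corresponding bit set. A minimal sketch of the
same idiom outside of Open MPI (local names here are illustrative, not from this
patch):

    ucp_context_h context;
    ucp_params_t  params;
    ucs_status_t  status;

    params.field_mask = UCP_PARAM_FIELD_FEATURES;  /* only 'features' is valid */
    params.features   = UCP_FEATURE_TAG;           /* request tag-matching support */
    status = ucp_init(&params, NULL, &context);    /* NULL config: library defaults */
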
@@ -140,6 +148,17 @@ int mca_pml_ucx_open(void)
         return OMPI_ERROR;
     }
 
+    /* Query UCX attributes */
+    attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
+    status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
+    if (UCS_OK != status) {
+        ucp_cleanup(ompi_pml_ucx.ucp_context);
+        ompi_pml_ucx.ucp_context = NULL;
+        return OMPI_ERROR;
+    }
+
+    ompi_pml_ucx.request_size = attr.request_size;
+
     return OMPI_SUCCESS;
 }
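
The queried request_size is the number of bytes UCP reserves for its own request
state. Caching it in ompi_pml_ucx lets the blocking receive path, rewritten
further down, carve that space out of the caller's stack frame instead of having
UCP allocate a request per call (see the note after that hunk).
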
@@ -163,7 +182,7 @@ int mca_pml_ucx_init(void)
 
     /* TODO check MPI thread mode */
     status = ucp_worker_create(ompi_pml_ucx.ucp_context, UCS_THREAD_MODE_SINGLE,
-                               &ompi_pml_ucx.ucp_worker);
+                               &ompi_pml_ucx.ucp_worker);
     if (UCS_OK != status) {
         return OMPI_ERROR;
     }
@@ -252,6 +271,7 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
     ucp_address_t *address;
     ucs_status_t status;
+    ompi_proc_t *proc;
     size_t addrlen;
     ucp_ep_h ep;
     size_t i;
@@ -264,47 +284,109 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
     }
 
     for (i = 0; i < nprocs; ++i) {
-        ret = mca_pml_ucx_recv_worker_address(procs[i], &address, &addrlen);
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+
+        ret = mca_pml_ucx_recv_worker_address(proc, &address, &addrlen);
         if (ret < 0) {
-            PML_UCX_ERROR("Failed to receive worker address from proc: %d", procs[i]->super.proc_name.vpid);
+            PML_UCX_ERROR("Failed to receive worker address from proc: %d",
+                          proc->super.proc_name.vpid);
             return ret;
         }
 
-        if (procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
-            PML_UCX_VERBOSE(3, "already connected to proc. %d", procs[i]->super.proc_name.vpid);
+        if (proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
+            PML_UCX_VERBOSE(3, "already connected to proc. %d", proc->super.proc_name.vpid);
             continue;
         }
 
-        PML_UCX_VERBOSE(2, "connecting to proc. %d", procs[i]->super.proc_name.vpid);
+        PML_UCX_VERBOSE(2, "connecting to proc. %d", proc->super.proc_name.vpid);
         status = ucp_ep_create(ompi_pml_ucx.ucp_worker, address, &ep);
         free(address);
 
         if (UCS_OK != status) {
-            PML_UCX_ERROR("Failed to connect to proc: %d, %s", procs[i]->super.proc_name.vpid,
-                          ucs_status_string(status));
+            PML_UCX_ERROR("Failed to connect to proc: %d, %s", proc->super.proc_name.vpid,
+                          ucs_status_string(status));
             return OMPI_ERROR;
         }
 
-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
     }
 
     return OMPI_SUCCESS;
 }
 
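Why the rotated index: with the old procs[i] order every rank contacted peer 0
first, then peer 1, and so on, so low-numbered ranks absorbed all of the
connection-setup traffic at the same time. Starting each rank's sweep at its own
vpid staggers the load: with nprocs = 4, rank 2 now visits peers in the order
2, 3, 0, 1 while rank 3 visits 3, 0, 1, 2. Each rank still visits every peer
exactly once, since i -> (i + vpid) mod nprocs is a bijection on 0..nprocs-1.
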
+static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
+{
+    ucs_status_t status;
+    size_t i;
+
+    PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", *count_p);
+    for (i = 0; i < *count_p; ++i) {
+        do {
+            opal_progress();
+            status = ucp_request_test(reqs[i], NULL);
+        } while (status == UCS_INPROGRESS);
+        if (status != UCS_OK) {
+            PML_UCX_ERROR("disconnect request failed: %s",
+                          ucs_status_string(status));
+        }
+        ucp_request_release(reqs[i]);
+        reqs[i] = NULL;
+    }
+
+    *count_p = 0;
+}
+
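The new waitall helper polls ucp_request_test() instead of blocking, and calls
opal_progress() on every iteration; opal_progress() runs all registered Open MPI
progress callbacks, including this PML's own mca_pml_ucx_progress(), so the UCX
worker and the rest of the process keep making progress while the disconnect
requests drain.
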
 int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
+    ompi_proc_t *proc;
+    size_t num_reqs, max_reqs;
+    void *dreq, **dreqs;
     ucp_ep_h ep;
     size_t i;
 
+    max_reqs = ompi_pml_ucx.num_disconnect;
+    if (max_reqs > nprocs) {
+        max_reqs = nprocs;
+    }
+
+    dreqs = malloc(sizeof(*dreqs) * max_reqs);
+    if (dreqs == NULL) {
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+
+    num_reqs = 0;
+
     for (i = 0; i < nprocs; ++i) {
-        PML_UCX_VERBOSE(2, "disconnecting from rank %d", procs[i]->super.proc_name.vpid);
-        ep = procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
-        if (ep != NULL) {
-            ucp_ep_destroy(ep);
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+        ep = proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
+        if (ep == NULL) {
+            continue;
+        }
+
+        PML_UCX_VERBOSE(2, "disconnecting from rank %d", proc->super.proc_name.vpid);
+        dreq = ucp_disconnect_nb(ep);
+        if (dreq != NULL) {
+            if (UCS_PTR_IS_ERR(dreq)) {
+                PML_UCX_ERROR("ucp_disconnect_nb(%d) failed: %s",
+                              proc->super.proc_name.vpid,
+                              ucs_status_string(UCS_PTR_STATUS(dreq)));
+            } else {
+                dreqs[num_reqs++] = dreq;
+            }
+        }
+
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
+
+        if (num_reqs >= ompi_pml_ucx.num_disconnect) {
+            mca_pml_ucx_waitall(dreqs, &num_reqs);
         }
-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
     }
+
+    mca_pml_ucx_waitall(dreqs, &num_reqs);
+    free(dreqs);
+
     opal_pmix.fence(NULL, 0);
+
     return OMPI_SUCCESS;
 }
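
Design note: endpoints are now torn down with the non-blocking ucp_disconnect_nb()
rather than the immediate ucp_ep_destroy(). At most ompi_pml_ucx.num_disconnect
disconnect requests (a tunable that this patch presumably registers elsewhere) are
kept in flight; when the window fills, mca_pml_ucx_waitall() drains it, and a
final waitall plus the PMIx fence ensures every rank has finished disconnecting
before any of them proceeds with shutdown. The same (i + vpid) % nprocs rotation
as in add_procs staggers which peer each rank disconnects from first.
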
@@ -321,14 +403,7 @@ int mca_pml_ucx_enable(bool enable)
 
 int mca_pml_ucx_progress(void)
 {
-    static int inprogress = 0;
-    if (inprogress != 0) {
-        return 0;
-    }
-
-    ++inprogress;
     ucp_worker_progress(ompi_pml_ucx.ucp_worker);
-    --inprogress;
     return OMPI_SUCCESS;
 }
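
The static 'inprogress' guard, which made nested calls into the progress
function return early, is simply dropped: ucp_worker_progress() is now invoked
unconditionally on every call.
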
@@ -393,52 +468,32 @@ int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
     return OMPI_SUCCESS;
 }
 
-static void
-mca_pml_ucx_blocking_recv_completion(void *request, ucs_status_t status,
-                                     ucp_tag_recv_info_t *info)
-{
-    ompi_request_t *req = request;
-
-    PML_UCX_VERBOSE(8, "blocking receive request %p completed with status %s tag %"PRIx64" len %zu",
-                    (void*)req, ucs_status_string(status), info->sender_tag,
-                    info->length);
-
-    mca_pml_ucx_set_recv_status(&req->req_status, status, info);
-    PML_UCX_ASSERT(!(REQUEST_COMPLETE(req)));
-    ompi_request_complete(req, true);
-}
-
 int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src,
                      int tag, struct ompi_communicator_t *comm,
                      ompi_status_public_t *mpi_status)
 {
     ucp_tag_t ucp_tag, ucp_tag_mask;
-    ompi_request_t *req;
+    ucp_tag_recv_info_t info;
+    ucs_status_t status;
+    void *req;
 
     PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");
 
     PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
-    req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
-                                           mca_pml_ucx_get_datatype(datatype),
-                                           ucp_tag, ucp_tag_mask,
-                                           mca_pml_ucx_blocking_recv_completion);
-    if (UCS_PTR_IS_ERR(req)) {
-        PML_UCX_ERROR("ucx recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
-        return OMPI_ERROR;
-    }
+    req = alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
+    status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
+                              mca_pml_ucx_get_datatype(datatype),
+                              ucp_tag, ucp_tag_mask, req);
 
     ucp_worker_progress(ompi_pml_ucx.ucp_worker);
-    while (!REQUEST_COMPLETE(req)) {
+    for (;;) {
+        status = ucp_request_test(req, &info);
+        if (status != UCS_INPROGRESS) {
+            mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info);
+            return OMPI_SUCCESS;
+        }
         opal_progress();
     }
-
-    if (mpi_status != MPI_STATUS_IGNORE) {
-        *mpi_status = req->req_status;
-    }
-
-    req->req_complete = REQUEST_PENDING;
-    ucp_request_release(req);
-    return OMPI_SUCCESS;
 }
 
 static inline const char *mca_pml_ucx_send_mode_name(mca_pml_base_send_mode_t mode)
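
How the on-stack request works: per the UCP API, ucp_tag_recv_nbr() takes a
request buffer supplied by the caller, and there must be at least the
request_size bytes queried at init time available immediately *before* the
pointer passed in. Hence alloca(request_size) + request_size, a pointer just
past the reserved block. A sketch of the layout (local names are illustrative):

    /*  base                              base + request_size
     *  |<-------- UCP request state -------->|
     *  ^                                     ^
     *  returned by alloca()                  value passed as 'req'
     */
    char *base = alloca(ompi_pml_ucx.request_size);
    void *req  = base + ompi_pml_ucx.request_size;

Because the request lives in the caller's stack frame, the blocking receive no
longer heap-allocates or releases a request, and the completion callback removed
above becomes unnecessary: the loop just polls ucp_request_test() until the
status leaves UCS_INPROGRESS, then translates it into the MPI status.
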
@@ -583,6 +638,7 @@ int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
     } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
@@ -628,7 +684,8 @@ int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
         PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
-    } else if (UCS_PTR_STATUS(ucp_msg) == UCS_ERR_NO_MESSAGE) {
+    } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
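
In both probe paths the miss branch now calls opal_progress(), so an application
spinning on MPI_Iprobe/MPI_Improbe still drives communication forward rather
than polling without progress. The improbe hunk also closes a small hole: the
old 'else if (UCS_PTR_STATUS(ucp_msg) == UCS_ERR_NO_MESSAGE)' left *matched
unset for any other non-match result, while the unconditional else always
reports *matched = 0.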