Commit b661307

Merge pull request #2218 from yosefe/topic/ucx-pml-spml-update
ucx: adapt pml_ucx and spml_ucx to new UCX APIs
2 parents 958e29f + 05ca466 commit b661307

File tree

7 files changed (+216, -71 lines)

ompi/mca/pml/ucx/pml_ucx.c

Lines changed: 114 additions & 57 deletions
@@ -70,8 +70,8 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
         1ul << (PML_UCX_TAG_BITS - 1),
         1ul << (PML_UCX_CONTEXT_BITS),
     },
-    NULL,
-    NULL
+    NULL,   /* ucp_context */
+    NULL    /* ucp_worker */
 };
 
 static int mca_pml_ucx_send_worker_address(void)
@@ -116,6 +116,7 @@ static int mca_pml_ucx_recv_worker_address(ompi_proc_t *proc,
 
 int mca_pml_ucx_open(void)
 {
+    ucp_context_attr_t attr;
     ucp_params_t params;
     ucp_config_t *config;
     ucs_status_t status;
@@ -128,10 +129,17 @@ int mca_pml_ucx_open(void)
         return OMPI_ERROR;
     }
 
+    /* Initialize UCX context */
+    params.field_mask      = UCP_PARAM_FIELD_FEATURES |
+                             UCP_PARAM_FIELD_REQUEST_SIZE |
+                             UCP_PARAM_FIELD_REQUEST_INIT |
+                             UCP_PARAM_FIELD_REQUEST_CLEANUP |
+                             UCP_PARAM_FIELD_TAG_SENDER_MASK;
     params.features        = UCP_FEATURE_TAG;
     params.request_size    = sizeof(ompi_request_t);
     params.request_init    = mca_pml_ucx_request_init;
     params.request_cleanup = mca_pml_ucx_request_cleanup;
+    params.tag_sender_mask = PML_UCX_SPECIFIC_SOURCE_MASK;
 
     status = ucp_init(&params, config, &ompi_pml_ucx.ucp_context);
     ucp_config_release(config);
@@ -140,6 +148,17 @@ int mca_pml_ucx_open(void)
         return OMPI_ERROR;
     }
 
+    /* Query UCX attributes */
+    attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
+    status = ucp_context_query(ompi_pml_ucx.ucp_context, &attr);
+    if (UCS_OK != status) {
+        ucp_cleanup(ompi_pml_ucx.ucp_context);
+        ompi_pml_ucx.ucp_context = NULL;
+        return OMPI_ERROR;
+    }
+
+    ompi_pml_ucx.request_size = attr.request_size;
+
     return OMPI_SUCCESS;
 }
 
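ucp_init() reads only the ucp_params_t fields named in params.field_mask, and the per-request scratch size that the context expects callers to reserve is then obtained from ucp_context_query(). A minimal standalone sketch of the same pattern, assuming nothing beyond an installed UCX; this program is illustrative and not part of the commit:

    #include <ucp/api/ucp.h>
    #include <stdio.h>

    int main(void)
    {
        ucp_params_t params = {0};
        ucp_config_t *config;
        ucp_context_h context;
        ucp_context_attr_t attr;
        ucs_status_t status;

        status = ucp_config_read(NULL, NULL, &config);
        if (status != UCS_OK) {
            return 1;
        }

        /* Only the fields named in field_mask are read by ucp_init() */
        params.field_mask = UCP_PARAM_FIELD_FEATURES;
        params.features   = UCP_FEATURE_TAG;

        status = ucp_init(&params, config, &context);
        ucp_config_release(config);
        if (status != UCS_OK) {
            return 1;
        }

        /* Ask the context how much scratch space each request needs */
        attr.field_mask = UCP_ATTR_FIELD_REQUEST_SIZE;
        status = ucp_context_query(context, &attr);
        if (status == UCS_OK) {
            printf("UCP request size: %zu bytes\n", attr.request_size);
        }

        ucp_cleanup(context);
        return 0;
    }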
@@ -163,7 +182,7 @@ int mca_pml_ucx_init(void)
 
     /* TODO check MPI thread mode */
     status = ucp_worker_create(ompi_pml_ucx.ucp_context, UCS_THREAD_MODE_SINGLE,
-                              &ompi_pml_ucx.ucp_worker);
+                               &ompi_pml_ucx.ucp_worker);
     if (UCS_OK != status) {
         return OMPI_ERROR;
     }
@@ -252,6 +271,7 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
     ucp_address_t *address;
     ucs_status_t status;
+    ompi_proc_t *proc;
     size_t addrlen;
     ucp_ep_h ep;
     size_t i;
@@ -264,47 +284,109 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs)
     }
 
     for (i = 0; i < nprocs; ++i) {
-        ret = mca_pml_ucx_recv_worker_address(procs[i], &address, &addrlen);
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+
+        ret = mca_pml_ucx_recv_worker_address(proc, &address, &addrlen);
         if (ret < 0) {
-            PML_UCX_ERROR("Failed to receive worker address from proc: %d", procs[i]->super.proc_name.vpid);
+            PML_UCX_ERROR("Failed to receive worker address from proc: %d",
+                          proc->super.proc_name.vpid);
             return ret;
         }
 
-        if (procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
-            PML_UCX_VERBOSE(3, "already connected to proc. %d", procs[i]->super.proc_name.vpid);
+        if (proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]) {
+            PML_UCX_VERBOSE(3, "already connected to proc. %d", proc->super.proc_name.vpid);
             continue;
         }
 
-        PML_UCX_VERBOSE(2, "connecting to proc. %d", procs[i]->super.proc_name.vpid);
+        PML_UCX_VERBOSE(2, "connecting to proc. %d", proc->super.proc_name.vpid);
         status = ucp_ep_create(ompi_pml_ucx.ucp_worker, address, &ep);
         free(address);
 
         if (UCS_OK != status) {
-            PML_UCX_ERROR("Failed to connect to proc: %d, %s", procs[i]->super.proc_name.vpid,
-                          ucs_status_string(status));
+            PML_UCX_ERROR("Failed to connect to proc: %d, %s", proc->super.proc_name.vpid,
+                          ucs_status_string(status));
             return OMPI_ERROR;
         }
 
-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep;
     }
 
     return OMPI_SUCCESS;
 }
 
+static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
+{
+    ucs_status_t status;
+    size_t i;
+
+    PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", *count_p);
+    for (i = 0; i < *count_p; ++i) {
+        do {
+            opal_progress();
+            status = ucp_request_test(reqs[i], NULL);
+        } while (status == UCS_INPROGRESS);
+        if (status != UCS_OK) {
+            PML_UCX_ERROR("disconnect request failed: %s",
+                          ucs_status_string(status));
+        }
+        ucp_request_release(reqs[i]);
+        reqs[i] = NULL;
+    }
+
+    *count_p = 0;
+}
+
 int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
 {
+    ompi_proc_t *proc;
+    size_t num_reqs, max_reqs;
+    void *dreq, **dreqs;
     ucp_ep_h ep;
     size_t i;
 
+    max_reqs = ompi_pml_ucx.num_disconnect;
+    if (max_reqs > nprocs) {
+        max_reqs = nprocs;
+    }
+
+    dreqs = malloc(sizeof(*dreqs) * max_reqs);
+    if (dreqs == NULL) {
+        return OMPI_ERR_OUT_OF_RESOURCE;
+    }
+
+    num_reqs = 0;
+
     for (i = 0; i < nprocs; ++i) {
-        PML_UCX_VERBOSE(2, "disconnecting from rank %d", procs[i]->super.proc_name.vpid);
-        ep = procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
-        if (ep != NULL) {
-            ucp_ep_destroy(ep);
+        proc = procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs];
+        ep = proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML];
+        if (ep == NULL) {
+            continue;
+        }
+
+        PML_UCX_VERBOSE(2, "disconnecting from rank %d", proc->super.proc_name.vpid);
+        dreq = ucp_disconnect_nb(ep);
+        if (dreq != NULL) {
+            if (UCS_PTR_IS_ERR(dreq)) {
+                PML_UCX_ERROR("ucp_disconnect_nb(%d) failed: %s",
+                              proc->super.proc_name.vpid,
+                              ucs_status_string(UCS_PTR_STATUS(dreq)));
+            } else {
+                dreqs[num_reqs++] = dreq;
+            }
+        }
+
+        proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
+
+        if (num_reqs >= ompi_pml_ucx.num_disconnect) {
+            mca_pml_ucx_waitall(dreqs, &num_reqs);
         }
-        procs[i]->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
     }
+
+    mca_pml_ucx_waitall(dreqs, &num_reqs);
+    free(dreqs);
+
     opal_pmix.fence(NULL, 0);
+
     return OMPI_SUCCESS;
 }
 
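Both loops above walk the peer list as procs[(i + OMPI_PROC_MY_NAME->vpid) % nprocs], so every rank starts connecting and disconnecting at its own offset instead of all ranks hitting rank 0 first; del_procs additionally caps outstanding ucp_disconnect_nb() requests at num_disconnect before draining them with mca_pml_ucx_waitall(). A tiny standalone illustration of the resulting visit order (hypothetical values, not part of the commit):

    #include <stdio.h>

    int main(void)
    {
        const int nprocs = 4;

        /* Each rank starts its peer loop at its own vpid */
        for (int vpid = 0; vpid < nprocs; ++vpid) {
            printf("rank %d visits:", vpid);
            for (int i = 0; i < nprocs; ++i) {
                printf(" %d", (i + vpid) % nprocs);
            }
            printf("\n");
        }
        return 0;
    }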
@@ -321,14 +403,7 @@ int mca_pml_ucx_enable(bool enable)
 
 int mca_pml_ucx_progress(void)
 {
-    static int inprogress = 0;
-    if (inprogress != 0) {
-        return 0;
-    }
-
-    ++inprogress;
     ucp_worker_progress(ompi_pml_ucx.ucp_worker);
-    --inprogress;
     return OMPI_SUCCESS;
 }
 
@@ -393,52 +468,32 @@ int mca_pml_ucx_irecv(void *buf, size_t count, ompi_datatype_t *datatype,
     return OMPI_SUCCESS;
 }
 
-static void
-mca_pml_ucx_blocking_recv_completion(void *request, ucs_status_t status,
-                                     ucp_tag_recv_info_t *info)
-{
-    ompi_request_t *req = request;
-
-    PML_UCX_VERBOSE(8, "blocking receive request %p completed with status %s tag %"PRIx64" len %zu",
-                    (void*)req, ucs_status_string(status), info->sender_tag,
-                    info->length);
-
-    mca_pml_ucx_set_recv_status(&req->req_status, status, info);
-    PML_UCX_ASSERT( !(REQUEST_COMPLETE(req)));
-    ompi_request_complete(req,true);
-}
-
 int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src,
                      int tag, struct ompi_communicator_t* comm,
                      ompi_status_public_t* mpi_status)
 {
     ucp_tag_t ucp_tag, ucp_tag_mask;
-    ompi_request_t *req;
+    ucp_tag_recv_info_t info;
+    ucs_status_t status;
+    void *req;
 
     PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");
 
     PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
-    req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker, buf, count,
-                                           mca_pml_ucx_get_datatype(datatype),
-                                           ucp_tag, ucp_tag_mask,
-                                           mca_pml_ucx_blocking_recv_completion);
-    if (UCS_PTR_IS_ERR(req)) {
-        PML_UCX_ERROR("ucx recv failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
-        return OMPI_ERROR;
-    }
+    req = alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
+    status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
+                              mca_pml_ucx_get_datatype(datatype),
+                              ucp_tag, ucp_tag_mask, req);
 
     ucp_worker_progress(ompi_pml_ucx.ucp_worker);
-    while ( !REQUEST_COMPLETE(req) ) {
+    for (;;) {
+        status = ucp_request_test(req, &info);
+        if (status != UCS_INPROGRESS) {
+            mca_pml_ucx_set_recv_status_safe(mpi_status, status, &info);
+            return OMPI_SUCCESS;
+        }
         opal_progress();
     }
-
-    if (mpi_status != MPI_STATUS_IGNORE) {
-        *mpi_status = req->req_status;
-    }
-
-    req->req_complete = REQUEST_PENDING;
-    ucp_request_release(req);
-    return OMPI_SUCCESS;
 }
 
 static inline const char *mca_pml_ucx_send_mode_name(mca_pml_base_send_mode_t mode)
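The rewritten blocking receive relies on ucp_tag_recv_nbr(), which takes caller-provided request storage instead of returning a UCX-allocated request: at the UCX API level this commit targets, the worker uses the request_size bytes immediately before the pointer it is handed, which is why the code points alloca()'d memory one request_size past its start and no longer calls ucp_request_release(). A sketch of the idiom in isolation (a hypothetical helper under those assumptions, given an already-created worker and the request_size queried in mca_pml_ucx_open()):

    #include <ucp/api/ucp.h>
    #include <alloca.h>

    static ucs_status_t recv_blocking_sketch(ucp_worker_h worker, void *buf,
                                             size_t count, ucp_datatype_t datatype,
                                             ucp_tag_t tag, ucp_tag_t tag_mask,
                                             size_t request_size)
    {
        ucp_tag_recv_info_t info;
        ucs_status_t status;
        /* UCX consumes the request_size bytes *before* this pointer */
        void *req = (char*)alloca(request_size) + request_size;

        status = ucp_tag_recv_nbr(worker, buf, count, datatype, tag, tag_mask, req);
        if (UCS_STATUS_IS_ERR(status)) {
            return status;
        }

        /* Poll until the request leaves the UCS_INPROGRESS state */
        do {
            ucp_worker_progress(worker);
            status = ucp_request_test(req, &info);
        } while (status == UCS_INPROGRESS);

        return status;
    }

Because the request lives on the caller's stack, there is nothing to hand back to UCX once it completes.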
@@ -583,6 +638,7 @@ int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
     } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
@@ -628,7 +684,8 @@ int mca_pml_ucx_improbe(int src, int tag, struct ompi_communicator_t* comm,
         PML_UCX_VERBOSE(8, "got message %p (%p)", (void*)*message, (void*)ucp_msg);
         *matched = 1;
         mca_pml_ucx_set_recv_status_safe(mpi_status, UCS_OK, &info);
-    } else if (UCS_PTR_STATUS(ucp_msg) == UCS_ERR_NO_MESSAGE) {
+    } else {
+        opal_progress();
         *matched = 0;
     }
     return OMPI_SUCCESS;
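
In both probe functions the unmatched branch now calls opal_progress(), so an application spinning on MPI_Iprobe or MPI_Improbe keeps driving communication forward instead of busy-waiting; the improbe path also reports no match for any non-message result rather than only when the status is exactly UCS_ERR_NO_MESSAGE.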

ompi/mca/pml/ucx/pml_ucx.h

Lines changed: 3 additions & 1 deletion
@@ -40,8 +40,10 @@ struct mca_pml_ucx_module {
     /* Requests */
     mca_pml_ucx_freelist_t persistent_reqs;
     ompi_request_t completed_send_req;
+    size_t request_size;
+    int num_disconnect;
 
-    /* Convertors pool */
+    /* Converters pool */
     mca_pml_ucx_freelist_t convs;
 
     int priority;

ompi/mca/pml/ucx/pml_ucx_component.c

Lines changed: 7 additions & 0 deletions
@@ -63,6 +63,13 @@ static int mca_pml_ucx_component_register(void)
                                            MCA_BASE_VAR_SCOPE_LOCAL,
                                            &ompi_pml_ucx.priority);
 
+    ompi_pml_ucx.num_disconnect = 1;
+    (void) mca_base_component_var_register(&mca_pml_ucx_component.pmlm_version, "num_disconnect",
+                                           "How many disconnects go in parallel",
+                                           MCA_BASE_VAR_TYPE_INT, NULL, 0, 0,
+                                           OPAL_INFO_LVL_3,
+                                           MCA_BASE_VAR_SCOPE_LOCAL,
+                                           &ompi_pml_ucx.num_disconnect);
    return 0;
 }
 
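Under the usual MCA naming scheme (framework_component_variable) the new variable should surface as pml_ucx_num_disconnect, e.g. mpirun --mca pml_ucx_num_disconnect 16; the full name is inferred from the registration above rather than stated by the commit. The default of 1 drains each disconnect request before the next one is posted.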
ompi/mca/pml/ucx/pml_ucx_request.h

Lines changed: 6 additions & 3 deletions
@@ -34,6 +34,9 @@ enum {
 #define PML_UCX_TAG_BITS             24
 #define PML_UCX_RANK_BITS            24
 #define PML_UCX_CONTEXT_BITS         16
+#define PML_UCX_ANY_SOURCE_MASK      0x800000000000fffful
+#define PML_UCX_SPECIFIC_SOURCE_MASK 0x800000fffffffffful
+#define PML_UCX_TAG_MASK             0x7fffff0000000000ul
 
 
 #define PML_UCX_MAKE_SEND_TAG(_tag, _comm) \
@@ -45,16 +48,16 @@ enum {
 #define PML_UCX_MAKE_RECV_TAG(_ucp_tag, _ucp_tag_mask, _tag, _src, _comm) \
     { \
         if ((_src) == MPI_ANY_SOURCE) { \
-            _ucp_tag_mask = 0x800000000000fffful; \
+            _ucp_tag_mask = PML_UCX_ANY_SOURCE_MASK; \
         } else { \
-            _ucp_tag_mask = 0x800000fffffffffful; \
+            _ucp_tag_mask = PML_UCX_SPECIFIC_SOURCE_MASK; \
         } \
         \
         _ucp_tag = (((uint64_t)(_src) & UCS_MASK(PML_UCX_RANK_BITS)) << PML_UCX_CONTEXT_BITS) | \
                    (_comm)->c_contextid; \
         \
         if ((_tag) != MPI_ANY_TAG) { \
-            _ucp_tag_mask |= 0x7fffff0000000000ul; \
+            _ucp_tag_mask |= PML_UCX_TAG_MASK; \
             _ucp_tag |= ((uint64_t)(_tag)) << (PML_UCX_RANK_BITS + PML_UCX_CONTEXT_BITS); \
         } \
     }
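
Read against the bit-width constants, the three masks imply a 64-bit UCP tag layout with a flag in bit 63, the 23 usable MPI tag bits in bits 62..40, the source rank in bits 39..16, and the communicator context id in bits 15..0; this layout is an interpretation derived from the constants, not spelled out in the commit. A small self-checking example:

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    #define PML_UCX_TAG_BITS     24
    #define PML_UCX_RANK_BITS    24
    #define PML_UCX_CONTEXT_BITS 16

    int main(void)
    {
        /* flag bit 63 + 16 context-id bits */
        uint64_t any_source = (1ull << 63) |
                              ((1ull << PML_UCX_CONTEXT_BITS) - 1);
        /* flag bit 63 + 24 rank bits + 16 context-id bits */
        uint64_t specific   = (1ull << 63) |
                              ((1ull << (PML_UCX_RANK_BITS + PML_UCX_CONTEXT_BITS)) - 1);
        /* 23 tag bits (the 24th is the flag bit) above rank + context */
        uint64_t tag        = ((1ull << (PML_UCX_TAG_BITS - 1)) - 1)
                              << (PML_UCX_RANK_BITS + PML_UCX_CONTEXT_BITS);

        assert(any_source == 0x800000000000fffful);
        assert(specific   == 0x800000fffffffffful);
        assert(tag        == 0x7fffff0000000000ul);
        printf("mask layout checks out\n");
        return 0;
    }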
