Skip to content

Commit fb67c96

Browse files
authored
Merge pull request #2944 from alex-mikheev/topic/pml_ucx_bsend
ompi: pml ucx: add support for the buffered send
2 parents 4ef6563 + b015c8b commit fb67c96

File tree

3 files changed

+131
-20
lines changed

3 files changed

+131
-20
lines changed

ompi/mca/pml/ucx/pml_ucx.c

Lines changed: 107 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "opal/runtime/opal.h"
1616
#include "opal/mca/pmix/pmix.h"
1717
#include "ompi/message/message.h"
18+
#include "ompi/mca/pml/base/pml_base_bsend.h"
1819
#include "pml_ucx_request.h"
1920

2021
#include <inttypes.h>
@@ -333,7 +334,7 @@ static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
333334
ucs_status_t status;
334335
size_t i;
335336

336-
PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", *count_p);
337+
PML_UCX_VERBOSE(2, "waiting for %d disconnect requests", (int)*count_p);
337338
for (i = 0; i < *count_p; ++i) {
338339
do {
339340
opal_progress();
@@ -343,7 +344,7 @@ static void mca_pml_ucx_waitall(void **reqs, size_t *count_p)
343344
PML_UCX_ERROR("disconnect request failed: %s",
344345
ucs_status_string(status));
345346
}
346-
ucp_request_release(reqs[i]);
347+
ucp_request_free(reqs[i]);
347348
reqs[i] = NULL;
348349
}
349350

@@ -391,7 +392,7 @@ int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs)
391392

392393
proc->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = NULL;
393394

394-
if (num_reqs >= ompi_pml_ucx.num_disconnect) {
395+
if ((int)num_reqs >= ompi_pml_ucx.num_disconnect) {
395396
mca_pml_ucx_waitall(dreqs, &num_reqs);
396397
}
397398
}
@@ -494,7 +495,7 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src
494495
PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");
495496

496497
PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
497-
req = alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
498+
req = (char *)alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
498499
status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
499500
mca_pml_ucx_get_datatype(datatype),
500501
ucp_tag, ucp_tag_mask, req);
@@ -556,15 +557,80 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
556557
req->flags = MCA_PML_UCX_REQUEST_FLAG_SEND;
557558
req->buffer = (void *)buf;
558559
req->count = count;
559-
req->datatype = mca_pml_ucx_get_datatype(datatype);
560560
req->tag = PML_UCX_MAKE_SEND_TAG(tag, comm);
561561
req->send.mode = mode;
562562
req->send.ep = ep;
563+
if (MCA_PML_BASE_SEND_BUFFERED == mode) {
564+
req->ompi_datatype = datatype;
565+
OBJ_RETAIN(datatype);
566+
} else {
567+
req->datatype = mca_pml_ucx_get_datatype(datatype);
568+
}
563569

564570
*request = &req->ompi;
565571
return OMPI_SUCCESS;
566572
}
567573

574+
static int
575+
mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
576+
ompi_datatype_t *datatype, uint64_t pml_tag)
577+
{
578+
ompi_request_t *req;
579+
void *packed_data;
580+
size_t packed_length;
581+
size_t offset;
582+
uint32_t iov_count;
583+
struct iovec iov;
584+
opal_convertor_t opal_conv;
585+
586+
OBJ_CONSTRUCT(&opal_conv, opal_convertor_t);
587+
opal_convertor_copy_and_prepare_for_recv(ompi_proc_local_proc->super.proc_convertor,
588+
&datatype->super, count, buf, 0,
589+
&opal_conv);
590+
opal_convertor_get_packed_size(&opal_conv, &packed_length);
591+
592+
packed_data = mca_pml_base_bsend_request_alloc_buf(packed_length);
593+
if (OPAL_UNLIKELY(NULL == packed_data)) {
594+
OBJ_DESTRUCT(&opal_conv);
595+
PML_UCX_ERROR("bsend: failed to allocate buffer");
596+
return OMPI_ERR_OUT_OF_RESOURCE;
597+
}
598+
599+
iov_count = 1;
600+
iov.iov_base = packed_data;
601+
iov.iov_len = packed_length;
602+
603+
PML_UCX_VERBOSE(8, "bsend of packed buffer %p len %d\n", packed_data, packed_length);
604+
offset = 0;
605+
opal_convertor_set_position(&opal_conv, &offset);
606+
if (0 > opal_convertor_pack(&opal_conv, &iov, &iov_count, &packed_length)) {
607+
mca_pml_base_bsend_request_free(packed_data);
608+
OBJ_DESTRUCT(&opal_conv);
609+
PML_UCX_ERROR("bsend: failed to pack user datatype");
610+
return OMPI_ERROR;
611+
}
612+
613+
OBJ_DESTRUCT(&opal_conv);
614+
615+
req = (ompi_request_t*)ucp_tag_send_nb(ep, packed_data, packed_length,
616+
ucp_dt_make_contig(1), pml_tag,
617+
mca_pml_ucx_bsend_completion);
618+
if (NULL == req) {
619+
/* request was completed in place */
620+
mca_pml_base_bsend_request_free(packed_data);
621+
return OMPI_SUCCESS;
622+
}
623+
624+
if (OPAL_UNLIKELY(UCS_PTR_IS_ERR(req))) {
625+
mca_pml_base_bsend_request_free(packed_data);
626+
PML_UCX_ERROR("ucx bsend failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
627+
return OMPI_ERROR;
628+
}
629+
630+
req->req_complete_cb_data = packed_data;
631+
return OMPI_SUCCESS;
632+
}
633+
568634
int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
569635
int dst, int tag, mca_pml_base_send_mode_t mode,
570636
struct ompi_communicator_t* comm,
@@ -573,8 +639,10 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
573639
ompi_request_t *req;
574640
ucp_ep_h ep;
575641

576-
PML_UCX_TRACE_SEND("isend request *%p", buf, count, datatype, dst, tag, mode,
577-
comm, (void*)request)
642+
PML_UCX_TRACE_SEND("i%ssend request *%p",
643+
buf, count, datatype, dst, tag, mode, comm,
644+
mode == MCA_PML_BASE_SEND_BUFFERED ? "b" : "",
645+
(void*)request)
578646

579647
/* TODO special care to sync/buffered send */
580648

@@ -584,6 +652,13 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
584652
return OMPI_ERROR;
585653
}
586654

655+
/* Special care to sync/buffered send */
656+
if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
657+
*request = &ompi_pml_ucx.completed_send_req;
658+
return mca_pml_ucx_bsend(ep, buf, count, datatype,
659+
PML_UCX_MAKE_SEND_TAG(tag, comm));
660+
}
661+
587662
req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
588663
mca_pml_ucx_get_datatype(datatype),
589664
PML_UCX_MAKE_SEND_TAG(tag, comm),
@@ -609,16 +684,21 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
609684
ompi_request_t *req;
610685
ucp_ep_h ep;
611686

612-
PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm, "send");
613-
614-
/* TODO special care to sync/buffered send */
687+
PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm,
688+
mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send");
615689

616690
ep = mca_pml_ucx_get_ep(comm, dst);
617691
if (OPAL_UNLIKELY(NULL == ep)) {
618692
PML_UCX_ERROR("Failed to get ep for rank %d", dst);
619693
return OMPI_ERROR;
620694
}
621695

696+
/* Special care to sync/buffered send */
697+
if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
698+
return mca_pml_ucx_bsend(ep, buf, count, datatype,
699+
PML_UCX_MAKE_SEND_TAG(tag, comm));
700+
}
701+
622702
req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
623703
mca_pml_ucx_get_datatype(datatype),
624704
PML_UCX_MAKE_SEND_TAG(tag, comm),
@@ -781,6 +861,7 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
781861
mca_pml_ucx_persistent_request_t *preq;
782862
ompi_request_t *tmp_req;
783863
size_t i;
864+
int rc;
784865

785866
for (i = 0; i < count; ++i) {
786867
preq = (mca_pml_ucx_persistent_request_t *)requests[i];
@@ -795,12 +876,22 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
795876
mca_pml_ucx_request_reset(&preq->ompi);
796877

797878
if (preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND) {
798-
/* TODO special care to sync/buffered send */
799-
PML_UCX_VERBOSE(8, "start send request %p", (void*)preq);
800-
tmp_req = (ompi_request_t*)ucp_tag_send_nb(preq->send.ep, preq->buffer,
801-
preq->count, preq->datatype,
802-
preq->tag,
803-
mca_pml_ucx_psend_completion);
879+
if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == preq->send.mode)) {
880+
PML_UCX_VERBOSE(8, "start bsend request %p", (void*)preq);
881+
rc = mca_pml_ucx_bsend(preq->send.ep, preq->buffer, preq->count,
882+
preq->ompi_datatype, preq->tag);
883+
if (OMPI_SUCCESS != rc) {
884+
return rc;
885+
}
886+
/* pretend that we got immediate completion */
887+
tmp_req = NULL;
888+
} else {
889+
PML_UCX_VERBOSE(8, "start send request %p", (void*)preq);
890+
tmp_req = (ompi_request_t*)ucp_tag_send_nb(preq->send.ep, preq->buffer,
891+
preq->count, preq->datatype,
892+
preq->tag,
893+
mca_pml_ucx_psend_completion);
894+
}
804895
} else {
805896
PML_UCX_VERBOSE(8, "start recv request %p", (void*)preq);
806897
tmp_req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker,

ompi/mca/pml/ucx/pml_ucx_request.c

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ static int mca_pml_ucx_request_free(ompi_request_t **rptr)
2424

2525
*rptr = MPI_REQUEST_NULL;
2626
mca_pml_ucx_request_reset(req);
27-
ucp_request_release(req);
27+
ucp_request_free(req);
2828
return OMPI_SUCCESS;
2929
}
3030

@@ -46,6 +46,18 @@ void mca_pml_ucx_send_completion(void *request, ucs_status_t status)
4646
ompi_request_complete(req, true);
4747
}
4848

49+
void mca_pml_ucx_bsend_completion(void *request, ucs_status_t status)
50+
{
51+
ompi_request_t *req = request;
52+
53+
PML_UCX_VERBOSE(8, "bsend request %p buffer %p completed with status %s", (void*)req,
54+
req->req_complete_cb_data, ucs_status_string(status));
55+
mca_pml_base_bsend_request_free(req->req_complete_cb_data);
56+
mca_pml_ucx_set_send_status(&req->req_status, status);
57+
PML_UCX_ASSERT( !(REQUEST_COMPLETE(req)));
58+
mca_pml_ucx_request_free(&req);
59+
}
60+
4961
void mca_pml_ucx_recv_completion(void *request, ucs_status_t status,
5062
ucp_tag_recv_info_t *info)
5163
{
@@ -75,7 +87,7 @@ mca_pml_ucx_persistent_request_complete(mca_pml_ucx_persistent_request_t *preq,
7587
ompi_request_complete(&preq->ompi, true);
7688
mca_pml_ucx_persistent_request_detach(preq, tmp_req);
7789
mca_pml_ucx_request_reset(tmp_req);
78-
ucp_request_release(tmp_req);
90+
ucp_request_free(tmp_req);
7991
}
8092

8193
static inline void mca_pml_ucx_preq_completion(ompi_request_t *tmp_req)
@@ -152,7 +164,10 @@ static int mca_pml_ucx_persistent_request_free(ompi_request_t **rptr)
152164
preq->ompi.req_state = OMPI_REQUEST_INVALID;
153165
if (tmp_req != NULL) {
154166
mca_pml_ucx_persistent_request_detach(preq, tmp_req);
155-
ucp_request_release(tmp_req);
167+
ucp_request_free(tmp_req);
168+
}
169+
if (MCA_PML_BASE_SEND_BUFFERED == preq->send.mode) {
170+
OBJ_RELEASE(preq->ompi_datatype);
156171
}
157172
PML_UCX_FREELIST_RETURN(&ompi_pml_ucx.persistent_reqs, &preq->ompi.super);
158173
*rptr = MPI_REQUEST_NULL;

ompi/mca/pml/ucx/pml_ucx_request.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,10 @@ struct pml_ucx_persistent_request {
9999
unsigned flags;
100100
void *buffer;
101101
size_t count;
102-
ucp_datatype_t datatype;
102+
union {
103+
ucp_datatype_t datatype;
104+
ompi_datatype_t *ompi_datatype;
105+
};
103106
ucp_tag_t tag;
104107
struct {
105108
mca_pml_base_send_mode_t mode;
@@ -118,6 +121,8 @@ void mca_pml_ucx_recv_completion(void *request, ucs_status_t status,
118121

119122
void mca_pml_ucx_psend_completion(void *request, ucs_status_t status);
120123

124+
void mca_pml_ucx_bsend_completion(void *request, ucs_status_t status);
125+
121126
void mca_pml_ucx_precv_completion(void *request, ucs_status_t status,
122127
ucp_tag_recv_info_t *info);
123128

0 commit comments

Comments
 (0)