Skip to content

Commit a8f65d8

Browse files
authored
Merge pull request #3021 from alex-mikheev/topic/pml_ucx_bsend_v2.0.x
ompi: pml ucx: add support for the buffered send
2 parents fefd4d1 + 22d688f commit a8f65d8

File tree

3 files changed

+127
-16
lines changed

3 files changed

+127
-16
lines changed

ompi/mca/pml/ucx/pml_ucx.c

Lines changed: 103 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "opal/runtime/opal.h"
1616
#include "opal/mca/pmix/pmix.h"
1717
#include "ompi/message/message.h"
18+
#include "ompi/mca/pml/base/pml_base_bsend.h"
1819
#include "pml_ucx_request.h"
1920

2021
#include <inttypes.h>
@@ -506,15 +507,80 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
506507
req->flags = MCA_PML_UCX_REQUEST_FLAG_SEND;
507508
req->buffer = (void *)buf;
508509
req->count = count;
509-
req->datatype = mca_pml_ucx_get_datatype(datatype);
510510
req->tag = PML_UCX_MAKE_SEND_TAG(tag, comm);
511511
req->send.mode = mode;
512512
req->send.ep = ep;
513+
if (MCA_PML_BASE_SEND_BUFFERED == mode) {
514+
req->ompi_datatype = datatype;
515+
OBJ_RETAIN(datatype);
516+
} else {
517+
req->datatype = mca_pml_ucx_get_datatype(datatype);
518+
}
513519

514520
*request = &req->ompi;
515521
return OMPI_SUCCESS;
516522
}
517523

524+
static int
525+
mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
526+
ompi_datatype_t *datatype, uint64_t pml_tag)
527+
{
528+
ompi_request_t *req;
529+
void *packed_data;
530+
size_t packed_length;
531+
size_t offset;
532+
uint32_t iov_count;
533+
struct iovec iov;
534+
opal_convertor_t opal_conv;
535+
536+
OBJ_CONSTRUCT(&opal_conv, opal_convertor_t);
537+
opal_convertor_copy_and_prepare_for_recv(ompi_proc_local_proc->super.proc_convertor,
538+
&datatype->super, count, buf, 0,
539+
&opal_conv);
540+
opal_convertor_get_packed_size(&opal_conv, &packed_length);
541+
542+
packed_data = mca_pml_base_bsend_request_alloc_buf(packed_length);
543+
if (OPAL_UNLIKELY(NULL == packed_data)) {
544+
OBJ_DESTRUCT(&opal_conv);
545+
PML_UCX_ERROR("bsend: failed to allocate buffer");
546+
return OMPI_ERR_OUT_OF_RESOURCE;
547+
}
548+
549+
iov_count = 1;
550+
iov.iov_base = packed_data;
551+
iov.iov_len = packed_length;
552+
553+
PML_UCX_VERBOSE(8, "bsend of packed buffer %p len %d\n", packed_data, packed_length);
554+
offset = 0;
555+
opal_convertor_set_position(&opal_conv, &offset);
556+
if (0 > opal_convertor_pack(&opal_conv, &iov, &iov_count, &packed_length)) {
557+
mca_pml_base_bsend_request_free(packed_data);
558+
OBJ_DESTRUCT(&opal_conv);
559+
PML_UCX_ERROR("bsend: failed to pack user datatype");
560+
return OMPI_ERROR;
561+
}
562+
563+
OBJ_DESTRUCT(&opal_conv);
564+
565+
req = (ompi_request_t*)ucp_tag_send_nb(ep, packed_data, packed_length,
566+
ucp_dt_make_contig(1), pml_tag,
567+
mca_pml_ucx_bsend_completion);
568+
if (NULL == req) {
569+
/* request was completed in place */
570+
mca_pml_base_bsend_request_free(packed_data);
571+
return OMPI_SUCCESS;
572+
}
573+
574+
if (OPAL_UNLIKELY(UCS_PTR_IS_ERR(req))) {
575+
mca_pml_base_bsend_request_free(packed_data);
576+
PML_UCX_ERROR("ucx bsend failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
577+
return OMPI_ERROR;
578+
}
579+
580+
req->req_complete_cb_data = packed_data;
581+
return OMPI_SUCCESS;
582+
}
583+
518584
int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
519585
int dst, int tag, mca_pml_base_send_mode_t mode,
520586
struct ompi_communicator_t* comm,
@@ -523,8 +589,10 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
523589
ompi_request_t *req;
524590
ucp_ep_h ep;
525591

526-
PML_UCX_TRACE_SEND("isend request *%p", buf, count, datatype, dst, tag, mode,
527-
comm, (void*)request)
592+
PML_UCX_TRACE_SEND("i%ssend request *%p",
593+
buf, count, datatype, dst, tag, mode, comm,
594+
mode == MCA_PML_BASE_SEND_BUFFERED ? "b" : "",
595+
(void*)request)
528596

529597
/* TODO special care to sync/buffered send */
530598

@@ -534,6 +602,13 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
534602
return OMPI_ERROR;
535603
}
536604

605+
/* Special care to sync/buffered send */
606+
if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
607+
*request = &ompi_pml_ucx.completed_send_req;
608+
return mca_pml_ucx_bsend(ep, buf, count, datatype,
609+
PML_UCX_MAKE_SEND_TAG(tag, comm));
610+
}
611+
537612
req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
538613
mca_pml_ucx_get_datatype(datatype),
539614
PML_UCX_MAKE_SEND_TAG(tag, comm),
@@ -559,16 +634,21 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
559634
ompi_request_t *req;
560635
ucp_ep_h ep;
561636

562-
PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm, "send");
563-
564-
/* TODO special care to sync/buffered send */
637+
PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm,
638+
mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send");
565639

566640
ep = mca_pml_ucx_get_ep(comm, dst);
567641
if (OPAL_UNLIKELY(NULL == ep)) {
568642
PML_UCX_ERROR("Failed to get ep for rank %d", dst);
569643
return OMPI_ERROR;
570644
}
571645

646+
/* Special care to sync/buffered send */
647+
if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
648+
return mca_pml_ucx_bsend(ep, buf, count, datatype,
649+
PML_UCX_MAKE_SEND_TAG(tag, comm));
650+
}
651+
572652
req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
573653
mca_pml_ucx_get_datatype(datatype),
574654
PML_UCX_MAKE_SEND_TAG(tag, comm),
@@ -729,6 +809,7 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
729809
mca_pml_ucx_persistent_request_t *preq;
730810
ompi_request_t *tmp_req;
731811
size_t i;
812+
int rc;
732813

733814
for (i = 0; i < count; ++i) {
734815
preq = (mca_pml_ucx_persistent_request_t *)requests[i];
@@ -743,12 +824,22 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
743824
mca_pml_ucx_request_reset(&preq->ompi);
744825

745826
if (preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND) {
746-
/* TODO special care to sync/buffered send */
747-
PML_UCX_VERBOSE(8, "start send request %p", (void*)preq);
748-
tmp_req = (ompi_request_t*)ucp_tag_send_nb(preq->send.ep, preq->buffer,
749-
preq->count, preq->datatype,
750-
preq->tag,
751-
mca_pml_ucx_psend_completion);
827+
if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == preq->send.mode)) {
828+
PML_UCX_VERBOSE(8, "start bsend request %p", (void*)preq);
829+
rc = mca_pml_ucx_bsend(preq->send.ep, preq->buffer, preq->count,
830+
preq->ompi_datatype, preq->tag);
831+
if (OMPI_SUCCESS != rc) {
832+
return rc;
833+
}
834+
/* pretend that we got immediate completion */
835+
tmp_req = NULL;
836+
} else {
837+
PML_UCX_VERBOSE(8, "start send request %p", (void*)preq);
838+
tmp_req = (ompi_request_t*)ucp_tag_send_nb(preq->send.ep, preq->buffer,
839+
preq->count, preq->datatype,
840+
preq->tag,
841+
mca_pml_ucx_psend_completion);
842+
}
752843
} else {
753844
PML_UCX_VERBOSE(8, "start recv request %p", (void*)preq);
754845
tmp_req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker,

ompi/mca/pml/ucx/pml_ucx_request.c

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ static int mca_pml_ucx_request_free(ompi_request_t **rptr)
2424

2525
*rptr = MPI_REQUEST_NULL;
2626
mca_pml_ucx_request_reset(req);
27-
ucp_request_release(req);
27+
ucp_request_free(req);
2828
return OMPI_SUCCESS;
2929
}
3030

@@ -46,6 +46,18 @@ void mca_pml_ucx_send_completion(void *request, ucs_status_t status)
4646
ompi_request_complete(req, true);
4747
}
4848

49+
void mca_pml_ucx_bsend_completion(void *request, ucs_status_t status)
50+
{
51+
ompi_request_t *req = request;
52+
53+
PML_UCX_VERBOSE(8, "bsend request %p buffer %p completed with status %s", (void*)req,
54+
req->req_complete_cb_data, ucs_status_string(status));
55+
mca_pml_base_bsend_request_free(req->req_complete_cb_data);
56+
mca_pml_ucx_set_send_status(&req->req_status, status);
57+
PML_UCX_ASSERT( !(REQUEST_COMPLETE(req)));
58+
mca_pml_ucx_request_free(&req);
59+
}
60+
4961
void mca_pml_ucx_recv_completion(void *request, ucs_status_t status,
5062
ucp_tag_recv_info_t *info)
5163
{
@@ -74,7 +86,7 @@ void mca_pml_ucx_persistent_request_complete(mca_pml_ucx_persistent_request_t *p
7486
ompi_request_complete(&preq->ompi, true);
7587
mca_pml_ucx_persistent_request_detach(preq, tmp_req);
7688
mca_pml_ucx_request_reset(tmp_req);
77-
ucp_request_release(tmp_req);
89+
ucp_request_free(tmp_req);
7890
}
7991

8092
static inline void mca_pml_ucx_preq_completion(ompi_request_t *tmp_req)
@@ -151,7 +163,10 @@ static int mca_pml_ucx_persistent_request_free(ompi_request_t **rptr)
151163
preq->ompi.req_state = OMPI_REQUEST_INVALID;
152164
if (tmp_req != NULL) {
153165
mca_pml_ucx_persistent_request_detach(preq, tmp_req);
154-
ucp_request_release(tmp_req);
166+
ucp_request_free(tmp_req);
167+
}
168+
if (MCA_PML_BASE_SEND_BUFFERED == preq->send.mode) {
169+
OBJ_RELEASE(preq->ompi_datatype);
155170
}
156171
PML_UCX_FREELIST_RETURN(&ompi_pml_ucx.persistent_reqs, &preq->ompi.super);
157172
*rptr = MPI_REQUEST_NULL;

ompi/mca/pml/ucx/pml_ucx_request.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ struct pml_ucx_persistent_request {
9696
unsigned flags;
9797
void *buffer;
9898
size_t count;
99-
ucp_datatype_t datatype;
99+
union {
100+
ucp_datatype_t datatype;
101+
ompi_datatype_t *ompi_datatype;
102+
};
100103
ucp_tag_t tag;
101104
struct {
102105
mca_pml_base_send_mode_t mode;
@@ -115,6 +118,8 @@ void mca_pml_ucx_recv_completion(void *request, ucs_status_t status,
115118

116119
void mca_pml_ucx_psend_completion(void *request, ucs_status_t status);
117120

121+
void mca_pml_ucx_bsend_completion(void *request, ucs_status_t status);
122+
118123
void mca_pml_ucx_precv_completion(void *request, ucs_status_t status,
119124
ucp_tag_recv_info_t *info);
120125

0 commit comments

Comments
 (0)