Skip to content

Commit 2c17fa7

Browse files
authored
Merge pull request #4727 from alex-mikheev/topic/pml_ucx_send_nbr_v3.1.x
ompi: pml/ucx: blocking send using ucp_tag_send_nbr
2 parents 71aba1b + 70ee536 commit 2c17fa7

File tree

2 files changed

+75
-16
lines changed

2 files changed

+75
-16
lines changed

config/ompi_check_ucx.m4

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,16 @@ AC_DEFUN([OMPI_CHECK_UCX],[
9090
fi
9191
done],
9292
[ompi_check_ucx_happy="no"])
93+
94+
old_CPPFLAGS="$CPPFLAGS"
95+
AS_IF([test -n "$ompi_check_ucx_dir"],
96+
[CPPFLAGS="$CPPFLAGS -I$ompi_check_ucx_dir/include"])
97+
AC_CHECK_DECLS([ucp_tag_send_nbr],
98+
[AC_DEFINE([HAVE_UCP_TAG_SEND_NBR],[1],
99+
[have ucp_tag_send_nbr()])], [],
100+
[#include <ucp/api/ucp.h>])
101+
CPPFLAGS=$old_CPPFLAGS
102+
93103
OPAL_SUMMARY_ADD([[Transports]],[[Open UCX]],[$1],[$ompi_check_ucx_happy])
94104
fi
95105

ompi/mca/pml/ucx/pml_ucx.c

Lines changed: 65 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ mca_pml_ucx_module_t ompi_pml_ucx = {
7575
NULL /* ucp_worker */
7676
};
7777

78+
/*
 * Reserve ompi_pml_ucx.request_size bytes on the caller's stack with alloca()
 * and evaluate to a char pointer positioned just past the end of that area.
 * NOTE(review): the end-of-buffer position presumably matches the UCX "nbr"
 * convention of keeping the library's private request data in the bytes
 * preceding the request handle -- confirm against the UCX API reference.
 *
 * The expansion is a single parenthesized expression with no trailing
 * semicolon, so it can be used wherever an expression is allowed; the caller
 * supplies the statement terminator.
 */
#define PML_UCX_REQ_ALLOCA() \
    ((char *)alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size)
80+
7881
static int mca_pml_ucx_send_worker_address(void)
7982
{
8083
ucp_address_t *address;
@@ -525,7 +528,7 @@ int mca_pml_ucx_recv(void *buf, size_t count, ompi_datatype_t *datatype, int src
525528
PML_UCX_TRACE_RECV("%s", buf, count, datatype, src, tag, comm, "recv");
526529

527530
PML_UCX_MAKE_RECV_TAG(ucp_tag, ucp_tag_mask, tag, src, comm);
528-
req = (char *)alloca(ompi_pml_ucx.request_size) + ompi_pml_ucx.request_size;
531+
req = PML_UCX_REQ_ALLOCA();
529532
status = ucp_tag_recv_nbr(ompi_pml_ucx.ucp_worker, buf, count,
530533
mca_pml_ucx_get_datatype(datatype),
531534
ucp_tag, ucp_tag_mask, req);
@@ -715,26 +718,18 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
715718
}
716719
}
717720

718-
int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, int dst,
719-
int tag, mca_pml_base_send_mode_t mode,
720-
struct ompi_communicator_t* comm)
721+
static inline __opal_attribute_always_inline__ int
722+
mca_pml_ucx_send_nb(ucp_ep_h ep, const void *buf, size_t count,
723+
ompi_datatype_t *datatype, ucp_datatype_t ucx_datatype,
724+
ucp_tag_t tag, mca_pml_base_send_mode_t mode,
725+
ucp_send_callback_t cb)
721726
{
722727
ompi_request_t *req;
723-
ucp_ep_h ep;
724-
725-
PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm,
726-
mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send");
727-
728-
ep = mca_pml_ucx_get_ep(comm, dst);
729-
if (OPAL_UNLIKELY(NULL == ep)) {
730-
PML_UCX_ERROR("Failed to get ep for rank %d", dst);
731-
return OMPI_ERROR;
732-
}
733728

734729
req = (ompi_request_t*)mca_pml_ucx_common_send(ep, buf, count, datatype,
735730
mca_pml_ucx_get_datatype(datatype),
736-
PML_UCX_MAKE_SEND_TAG(tag, comm),
737-
mode, mca_pml_ucx_send_completion);
731+
tag, mode,
732+
mca_pml_ucx_send_completion);
738733

739734
if (OPAL_LIKELY(req == NULL)) {
740735
return OMPI_SUCCESS;
@@ -749,6 +744,60 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
749744
}
750745
}
751746

747+
#if HAVE_DECL_UCP_TAG_SEND_NBR
/**
 * Blocking tagged send built on ucp_tag_send_nbr() with a stack-allocated
 * request.
 *
 * @param ep            UCX endpoint of the destination.
 * @param buf           Send buffer.
 * @param count         Number of elements in buf.
 * @param ucx_datatype  UCX datatype describing the elements.
 * @param tag           Fully-built UCX send tag.
 *
 * @return OMPI_SUCCESS when the send completed (immediately or after
 *         progressing), OMPI_ERROR when UCX reported any other status.
 */
static inline __opal_attribute_always_inline__ int
mca_pml_ucx_send_nbr(ucp_ep_h ep, const void *buf, size_t count,
                     ucp_datatype_t ucx_datatype, ucp_tag_t tag)
{
    void *req;
    ucs_status_t status;

    req    = PML_UCX_REQ_ALLOCA();
    status = ucp_tag_send_nbr(ep, buf, count, ucx_datatype, tag, req);
    if (OPAL_LIKELY(status == UCS_OK)) {
        /* Completed in place; the request was never needed. */
        return OMPI_SUCCESS;
    }

    if (OPAL_UNLIKELY(status != UCS_INPROGRESS)) {
        /*
         * The send failed up front. The UCX documentation only describes the
         * request as usable when UCS_INPROGRESS is returned, so do not poll
         * it with ucp_request_check_status() in this case.
         */
        return OMPI_ERROR;
    }

    /* Progress the UCX worker once, then fall back to full opal_progress(). */
    ucp_worker_progress(ompi_pml_ucx.ucp_worker);
    while ((status = ucp_request_check_status(req)) == UCS_INPROGRESS) {
        opal_progress();
    }

    return OPAL_LIKELY(UCS_OK == status) ? OMPI_SUCCESS : OMPI_ERROR;
}
#endif
770+
771+
/*
 * Blocking send entry point.
 *
 * Looks up the endpoint for rank dst in comm and returns OMPI_ERROR if none
 * is available. When ucp_tag_send_nbr() was detected at configure time and
 * the mode is neither buffered nor synchronous, the send is delegated to
 * mca_pml_ucx_send_nbr(); every other case goes through mca_pml_ucx_send_nb().
 * The return value is whatever the selected helper returns.
 */
int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, int dst,
                     int tag, mca_pml_base_send_mode_t mode,
                     struct ompi_communicator_t* comm)
{
    ucp_ep_h ep;
    ucp_datatype_t ucx_dt;
    ucp_tag_t ucx_tag;

    PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm,
                       mode == MCA_PML_BASE_SEND_BUFFERED ? "bsend" : "send");

    ep = mca_pml_ucx_get_ep(comm, dst);
    if (OPAL_UNLIKELY(NULL == ep)) {
        PML_UCX_ERROR("Failed to get ep for rank %d", dst);
        return OMPI_ERROR;
    }

    /* Both send paths need the same UCX datatype and tag; build them once. */
    ucx_dt  = mca_pml_ucx_get_datatype(datatype);
    ucx_tag = PML_UCX_MAKE_SEND_TAG(tag, comm);

#if HAVE_DECL_UCP_TAG_SEND_NBR
    /* Any mode other than buffered or synchronous takes the nbr path. */
    if (OPAL_LIKELY(!((MCA_PML_BASE_SEND_BUFFERED == mode) ||
                      (MCA_PML_BASE_SEND_SYNCHRONOUS == mode)))) {
        return mca_pml_ucx_send_nbr(ep, buf, count, ucx_dt, ucx_tag);
    }
#endif

    return mca_pml_ucx_send_nb(ep, buf, count, datatype, ucx_dt, ucx_tag, mode,
                               mca_pml_ucx_send_completion);
}
800+
752801
int mca_pml_ucx_iprobe(int src, int tag, struct ompi_communicator_t* comm,
753802
int *matched, ompi_status_public_t* mpi_status)
754803
{

0 commit comments

Comments
 (0)