From dca13878382b41681ec73a864940144b040456b5 Mon Sep 17 00:00:00 2001 From: Sergey Lebedev Date: Sun, 4 May 2025 12:46:16 +0200 Subject: [PATCH] coll/ucc: refactor UCC collective operations to handle MPI_IN_PLACE correctly Update the initialization functions for allgather, allgatherv, alltoall, and alltoallv to improve handling of the MPI_IN_PLACE argument. When MPI_IN_PLACE is passed as the send buffer to these collectives, the send datatype and count must be ignored. bot:notacherrypick Signed-off-by: Sergey Lebedev --- ompi/mca/coll/ucc/coll_ucc_allgather.c | 13 +++++++++---- ompi/mca/coll/ucc/coll_ucc_allgatherv.c | 10 +++++++--- ompi/mca/coll/ucc/coll_ucc_alltoall.c | 13 +++++++++---- ompi/mca/coll/ucc/coll_ucc_alltoallv.c | 10 +++++++--- 4 files changed, 32 insertions(+), 14 deletions(-) diff --git a/ompi/mca/coll/ucc/coll_ucc_allgather.c b/ompi/mca/coll/ucc/coll_ucc_allgather.c index 4be8f953fb3..30b7e10da64 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgather.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgather.c @@ -15,15 +15,20 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_size = ompi_comm_size(ucc_module->comm); - if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount) || + if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) || !ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) { goto fallback; } - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt || COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", @@ -49,7 +54,7 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t } 
}; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c index aecd45bfe86..96fd3a460d4 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c @@ -17,10 +17,14 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt || COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", @@ -47,7 +51,7 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t } }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoall.c b/ompi/mca/coll/ucc/coll_ucc_alltoall.c index 2cacfc25b2a..7fcf9edd133 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoall.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoall.c @@ -15,15 +15,20 @@ static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t s ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_size = ompi_comm_size(ucc_module->comm); - if (!ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size) || + if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size)) || !ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * 
comm_size)) { goto fallback; } - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt || COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", @@ -49,7 +54,7 @@ static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t s } }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; } diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c index 75b0dd6b6b7..0b730e12b4f 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c @@ -17,10 +17,14 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, const i ucc_coll_req_h *req, mca_coll_ucc_req_t *coll_req) { - ucc_datatype_t ucc_sdt, ucc_rdt; + ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; + bool is_inplace = (MPI_IN_PLACE == sbuf); - ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); ucc_rdt = ompi_dtype_to_ucc_dtype(rdtype); + if (!is_inplace) { + ucc_sdt = ompi_dtype_to_ucc_dtype(sdtype); + } + if (COLL_UCC_DT_UNSUPPORTED == ucc_sdt || COLL_UCC_DT_UNSUPPORTED == ucc_rdt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", @@ -48,7 +52,7 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, const i } }; - if (MPI_IN_PLACE == sbuf) { + if (is_inplace) { coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; }