diff --git a/ompi/mca/coll/ucc/coll_ucc.h b/ompi/mca/coll/ucc/coll_ucc.h index 510e4796448..da2d1d2e141 100644 --- a/ompi/mca/coll/ucc/coll_ucc.h +++ b/ompi/mca/coll/ucc/coll_ucc.h @@ -1,6 +1,7 @@ /** Copyright (c) 2021 Mellanox Technologies. All rights reserved. Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + Copyright (c) 2025 Fujitsu Limited. All rights reserved. $COPYRIGHT$ Additional copyrights may follow @@ -61,6 +62,7 @@ struct mca_coll_ucc_component_t { ucc_lib_attr_t ucc_lib_attr; ucc_coll_type_t cts_requested; ucc_coll_type_t nb_cts_requested; + ucc_coll_type_t ps_cts_requested; ucc_context_h ucc_context; opal_free_list_t requests; }; @@ -132,6 +134,34 @@ struct mca_coll_ucc_module_t { mca_coll_base_module_t* previous_scatter_module; mca_coll_base_module_iscatter_fn_t previous_iscatter; mca_coll_base_module_t* previous_iscatter_module; + mca_coll_base_module_allreduce_init_fn_t previous_allreduce_init; + mca_coll_base_module_t* previous_allreduce_init_module; + mca_coll_base_module_reduce_init_fn_t previous_reduce_init; + mca_coll_base_module_t* previous_reduce_init_module; + mca_coll_base_module_barrier_init_fn_t previous_barrier_init; + mca_coll_base_module_t* previous_barrier_init_module; + mca_coll_base_module_bcast_init_fn_t previous_bcast_init; + mca_coll_base_module_t* previous_bcast_init_module; + mca_coll_base_module_alltoall_init_fn_t previous_alltoall_init; + mca_coll_base_module_t* previous_alltoall_init_module; + mca_coll_base_module_alltoallv_init_fn_t previous_alltoallv_init; + mca_coll_base_module_t* previous_alltoallv_init_module; + mca_coll_base_module_allgather_init_fn_t previous_allgather_init; + mca_coll_base_module_t* previous_allgather_init_module; + mca_coll_base_module_allgatherv_init_fn_t previous_allgatherv_init; + mca_coll_base_module_t* previous_allgatherv_init_module; + mca_coll_base_module_gather_init_fn_t previous_gather_init; + mca_coll_base_module_t* previous_gather_init_module; + mca_coll_base_module_gatherv_init_fn_t previous_gatherv_init; + mca_coll_base_module_t* previous_gatherv_init_module; + mca_coll_base_module_reduce_scatter_block_init_fn_t previous_reduce_scatter_block_init; + mca_coll_base_module_t* previous_reduce_scatter_block_init_module; + mca_coll_base_module_reduce_scatter_init_fn_t previous_reduce_scatter_init; + mca_coll_base_module_t* previous_reduce_scatter_init_module; + mca_coll_base_module_scatterv_init_fn_t previous_scatterv_init; + mca_coll_base_module_t* previous_scatterv_init_module; + mca_coll_base_module_scatter_init_fn_t previous_scatter_init; + mca_coll_base_module_t* previous_scatter_init_module; }; typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t; OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t); @@ -305,5 +335,78 @@ int mca_coll_ucc_iscatter(const void *sbuf, size_t scount, ompi_request_t** request, mca_coll_base_module_t *module); +int mca_coll_ucc_allreduce_init(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_barrier_init(struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_bcast_init(void *buff, size_t count, struct ompi_datatype_t *datatype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_alltoall_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_count_array_t scounts, + ompi_disp_array_t sdisps, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, + struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, + struct ompi_info_t *info, ompi_request_t **request, + mca_coll_base_module_t *module); + +int mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, + struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, + struct ompi_info_t *info, ompi_request_t **request, + mca_coll_base_module_t *module); + +int mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_gatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, + struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf, size_t rcount, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, + struct ompi_info_t *info, ompi_request_t **request, + mca_coll_base_module_t *module); + +int mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi_count_array_t rcounts, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t scounts, + ompi_disp_array_t disps, struct ompi_datatype_t *sdtype, void *rbuf, + size_t rcount, struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + +int mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module); + END_C_DECLS #endif diff --git a/ompi/mca/coll/ucc/coll_ucc_allgather.c b/ompi/mca/coll/ucc/coll_ucc_allgather.c index 2dd3ac68a55..2362cc038a1 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgather.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgather.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -9,15 +10,17 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, - void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_allgather_init_common(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_size = ompi_comm_size(ucc_module->comm); + uint64_t flags = 0; if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) || !ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) { @@ -37,9 +40,12 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t goto fallback; } + flags = (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_ALLGATHER, .src.info = { .buffer = (void*)sbuf, @@ -55,10 +61,6 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t } }; - if (is_inplace) { - coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -74,9 +76,9 @@ int mca_coll_ucc_allgather(const void *sbuf, size_t scount, struct ompi_datatype ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc allgather"); - COLL_UCC_CHECK(mca_coll_ucc_allgather_init(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_allgather_init_common(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -98,9 +100,9 @@ int mca_coll_ucc_iallgather(const void *sbuf, size_t scount, struct ompi_datatyp UCC_VERBOSE(3, "running ucc iallgather"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allgather_init(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_allgather_init_common(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -112,3 +114,29 @@ int mca_coll_ucc_iallgather(const void *sbuf, size_t scount, struct ompi_datatyp return ucc_module->previous_iallgather(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, request, ucc_module->previous_iallgather_module); } + +int mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PERSISTENT(coll_req); + UCC_VERBOSE(3, "allgather_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_allgather_init_common(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback allgather_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_allgather_init(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, + info, request, + ucc_module->previous_allgather_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c index 68e786e0c2a..a2958496c70 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allgatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_allgatherv.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -9,13 +10,14 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, - struct ompi_datatype_t *sdtype, - void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, - struct ompi_datatype_t *rdtype, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_allgatherv_init_common(const void *sbuf, size_t scount, + struct ompi_datatype_t *sdtype, + void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, + struct ompi_datatype_t *rdtype, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -36,7 +38,8 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) | (ompi_disp_array_is_64bit(rdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) | - (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0); + (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); ucc_coll_args_t coll = { .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, @@ -75,9 +78,9 @@ int mca_coll_ucc_allgatherv(const void *sbuf, size_t scount, UCC_VERBOSE(3, "running ucc allgatherv"); - COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init(sbuf, scount, sdtype, - rbuf, rcounts, rdisps, rdtype, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init_common(sbuf, scount, sdtype, + rbuf, rcounts, rdisps, rdtype, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -102,9 +105,9 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, size_t scount, UCC_VERBOSE(3, "running ucc iallgatherv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init(sbuf, scount, sdtype, - rbuf, rcounts, rdisps, rdtype, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init_common(sbuf, scount, sdtype, + rbuf, rcounts, rdisps, rdtype, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -117,3 +120,30 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, size_t scount, rbuf, rcounts, rdisps, rdtype, comm, request, ucc_module->previous_iallgatherv_module); } + +int mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, + struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, + struct ompi_info_t *info, ompi_request_t **request, + mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PERSISTENT(coll_req); + UCC_VERBOSE(3, "allgatherv_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init_common(sbuf, scount, sdtype, + rbuf, rcounts, rdisps, rdtype, + true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback allgatherv_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_allgatherv_init(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, + comm, info, request, + ucc_module->previous_allgatherv_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_allreduce.c b/ompi/mca/coll/ucc/coll_ucc_allreduce.c index d44b93df07e..ac8c990a939 100644 --- a/ompi/mca/coll/ucc/coll_ucc_allreduce.c +++ b/ompi/mca/coll/ucc/coll_ucc_allreduce.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -9,14 +10,15 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_allreduce_init(const void *sbuf, void *rbuf, size_t count, - struct ompi_datatype_t *dtype, - struct ompi_op_t *op, mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t mca_coll_ucc_allreduce_init_common(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_dt; ucc_reduction_op_t ucc_op; + uint64_t flags = 0; ucc_dt = ompi_dtype_to_ucc_dtype(dtype); ucc_op = ompi_op_to_ucc_op(op); @@ -30,9 +32,13 @@ static inline ucc_status_t mca_coll_ucc_allreduce_init(const void *sbuf, void *r op->o_name); goto fallback; } + + flags = ((MPI_IN_PLACE == sbuf) ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_ALLREDUCE, .src.info = { .buffer = (void*)sbuf, @@ -48,10 +54,7 @@ static inline ucc_status_t mca_coll_ucc_allreduce_init(const void *sbuf, void *r }, .op = ucc_op, }; - if (MPI_IN_PLACE == sbuf) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } + COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -67,8 +70,8 @@ int mca_coll_ucc_allreduce(const void *sbuf, void *rbuf, size_t count, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc allreduce"); - COLL_UCC_CHECK(mca_coll_ucc_allreduce_init(sbuf, rbuf, count, dtype, op, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_allreduce_init_common(sbuf, rbuf, count, dtype, op, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -90,8 +93,8 @@ int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, size_t count, UCC_VERBOSE(3, "running ucc iallreduce"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_allreduce_init(sbuf, rbuf, count, dtype, op, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_allreduce_init_common(sbuf, rbuf, count, dtype, op, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -103,3 +106,27 @@ int mca_coll_ucc_iallreduce(const void *sbuf, void *rbuf, size_t count, return ucc_module->previous_iallreduce(sbuf, rbuf, count, dtype, op, comm, request, ucc_module->previous_iallreduce_module); } + +int mca_coll_ucc_allreduce_init(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PERSISTENT(coll_req); + UCC_VERBOSE(3, "allreduce_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_allreduce_init_common(sbuf, rbuf, count, dtype, op, + true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback allreduce_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_allreduce_init(sbuf, rbuf, count, dtype, op, comm, info, request, + ucc_module->previous_allreduce_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoall.c b/ompi/mca/coll/ucc/coll_ucc_alltoall.c index cfb56f47418..f61171576b2 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoall.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoall.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -9,15 +10,17 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, - void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_alltoall_init_common(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_size = ompi_comm_size(ucc_module->comm); + uint64_t flags = 0; if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount * comm_size)) || !ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) { @@ -37,9 +40,12 @@ static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t s goto fallback; } + flags = (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_ALLTOALL, .src.info = { .buffer = (void*)sbuf, @@ -55,10 +61,6 @@ static inline ucc_status_t mca_coll_ucc_alltoall_init(const void *sbuf, size_t s } }; - if (is_inplace) { - coll.mask = UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -74,16 +76,16 @@ int mca_coll_ucc_alltoall(const void *sbuf, size_t scount, struct ompi_datatype_ ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc alltoall"); - COLL_UCC_CHECK(mca_coll_ucc_alltoall_init(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_alltoall_init_common(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; fallback: UCC_VERBOSE(3, "running fallback alltoall"); return ucc_module->previous_alltoall(sbuf, scount, sdtype, rbuf, rcount, rdtype, - comm, ucc_module->previous_alltoall_module); + comm, ucc_module->previous_alltoall_module); } int mca_coll_ucc_ialltoall(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, @@ -98,9 +100,9 @@ int mca_coll_ucc_ialltoall(const void *sbuf, size_t scount, struct ompi_datatype UCC_VERBOSE(3, "running ucc ialltoall"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_alltoall_init(sbuf, scount, sdtype, - rbuf, rcount, rdtype, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_alltoall_init_common(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -112,3 +114,29 @@ int mca_coll_ucc_ialltoall(const void *sbuf, size_t scount, struct ompi_datatype return ucc_module->previous_ialltoall(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, request, ucc_module->previous_ialltoall_module); } + +int mca_coll_ucc_alltoall_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PERSISTENT(coll_req); + UCC_VERBOSE(3, "alltoall_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_alltoall_init_common(sbuf, scount, sdtype, + rbuf, rcount, rdtype, + true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback alltoall_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_alltoall_init(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm, + info, request, + ucc_module->previous_alltoall_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c index 1e9e311cf94..ce9b7e03fee 100644 --- a/ompi/mca/coll/ucc/coll_ucc_alltoallv.c +++ b/ompi/mca/coll/ucc/coll_ucc_alltoallv.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -9,13 +10,14 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_count_array_t scounts, - ompi_disp_array_t sdisps, struct ompi_datatype_t *sdtype, - void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, - struct ompi_datatype_t *rdtype, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_alltoallv_init_common(const void *sbuf, ompi_count_array_t scounts, + ompi_disp_array_t sdisps, struct ompi_datatype_t *sdtype, + void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, + struct ompi_datatype_t *rdtype, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -37,7 +39,8 @@ static inline ucc_status_t mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_co /* Assumes that send counts/displs and recv counts/displs are both 32-bit or both 64-bit */ flags = (ompi_count_array_is_64bit(scounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) | (ompi_disp_array_is_64bit(sdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) | - (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0); + (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); ucc_coll_args_t coll = { .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, @@ -77,9 +80,9 @@ int mca_coll_ucc_alltoallv(const void *sbuf, ompi_count_array_t scounts, UCC_VERBOSE(3, "running ucc alltoallv"); - COLL_UCC_CHECK(mca_coll_ucc_alltoallv_init(sbuf, scounts, sdisps, sdtype, - rbuf, rcounts, rdisps, rdtype, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_alltoallv_init_common(sbuf, scounts, sdisps, sdtype, + rbuf, rcounts, rdisps, rdtype, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -104,9 +107,9 @@ int mca_coll_ucc_ialltoallv(const void *sbuf, ompi_count_array_t scounts, UCC_VERBOSE(3, "running ucc ialltoallv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_alltoallv_init(sbuf, scounts, sdisps, sdtype, - rbuf, rcounts, rdisps, rdtype, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_alltoallv_init_common(sbuf, scounts, sdisps, sdtype, + rbuf, rcounts, rdisps, rdtype, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -116,6 +119,34 @@ int mca_coll_ucc_ialltoallv(const void *sbuf, ompi_count_array_t scounts, mca_coll_ucc_req_free((ompi_request_t **)&coll_req); } return ucc_module->previous_ialltoallv(sbuf, scounts, sdisps, sdtype, - rbuf, rcounts, rdisps, rdtype, + rbuf, rcounts, rdisps, rdtype, comm, request, ucc_module->previous_ialltoallv_module); } + +int mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_count_array_t scounts, + ompi_disp_array_t sdisps, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, + struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, + struct ompi_info_t *info, ompi_request_t **request, + mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PERSISTENT(coll_req); + UCC_VERBOSE(3, "alltoallv_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_alltoallv_init_common(sbuf, scounts, sdisps, sdtype, + rbuf, rcounts, rdisps, rdtype, + true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback alltoallv_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_alltoallv_init(sbuf, scounts, sdisps, sdtype, rbuf, rcounts, rdisps, + rdtype, comm, info, request, + ucc_module->previous_alltoallv_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_barrier.c b/ompi/mca/coll/ucc/coll_ucc_barrier.c index 010ca177fdf..da886e56f54 100644 --- a/ompi/mca/coll/ucc/coll_ucc_barrier.c +++ b/ompi/mca/coll/ucc/coll_ucc_barrier.c @@ -1,5 +1,6 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -8,15 +9,20 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_barrier_init(mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t mca_coll_ucc_barrier_init_common(bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { + uint64_t flags = 0; + + flags = (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_BARRIER }; + COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -30,7 +36,7 @@ int mca_coll_ucc_barrier(struct ompi_communicator_t *comm, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc barrier"); - COLL_UCC_CHECK(mca_coll_ucc_barrier_init(ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_barrier_init_common(false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -49,7 +55,7 @@ int mca_coll_ucc_ibarrier(struct ompi_communicator_t *comm, UCC_VERBOSE(3, "running ucc ibarrier"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_barrier_init(ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_barrier_init_common(false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -61,3 +67,24 @@ int mca_coll_ucc_ibarrier(struct ompi_communicator_t *comm, return ucc_module->previous_ibarrier(comm, request, ucc_module->previous_ibarrier_module); } + +int mca_coll_ucc_barrier_init(struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PERSISTENT(coll_req); + UCC_VERBOSE(3, "barrier_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_barrier_init_common(true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback barrier_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_barrier_init(comm, info, request, + ucc_module->previous_barrier_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_bcast.c b/ompi/mca/coll/ucc/coll_ucc_bcast.c index 34296d74b97..8da3c839133 100644 --- a/ompi/mca/coll/ucc/coll_ucc_bcast.c +++ b/ompi/mca/coll/ucc/coll_ucc_bcast.c @@ -1,5 +1,6 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -8,20 +9,25 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_bcast_init(void *buf, size_t count, struct ompi_datatype_t *dtype, - int root, mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_bcast_init_common(void *buf, size_t count, struct ompi_datatype_t *dtype, + int root, bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_dt = ompi_dtype_to_ucc_dtype(dtype); + uint64_t flags = 0; + if (COLL_UCC_DT_UNSUPPORTED == ucc_dt) { UCC_VERBOSE(5, "ompi_datatype is not supported: dtype = %s", dtype->super.name); goto fallback; } + flags = (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_BCAST, .root = root, .src.info = { @@ -31,6 +37,7 @@ static inline ucc_status_t mca_coll_ucc_bcast_init(void *buf, size_t count, stru .mem_type = UCC_MEMORY_TYPE_UNKNOWN } }; + COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -44,15 +51,15 @@ int mca_coll_ucc_bcast(void *buf, size_t count, struct ompi_datatype_t *dtype, mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t*)module; ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc bcast"); - COLL_UCC_CHECK(mca_coll_ucc_bcast_init(buf, count, dtype, root, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_bcast_init_common(buf, count, dtype, root, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; fallback: UCC_VERBOSE(3, "running fallback bcast"); return ucc_module->previous_bcast(buf, count, dtype, root, - comm, ucc_module->previous_bcast_module); + comm, ucc_module->previous_bcast_module); } int mca_coll_ucc_ibcast(void *buf, size_t count, struct ompi_datatype_t *dtype, @@ -66,8 +73,8 @@ int mca_coll_ucc_ibcast(void *buf, size_t count, struct ompi_datatype_t *dtype, UCC_VERBOSE(3, "running ucc ibcast"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_bcast_init(buf, count, dtype, root, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_bcast_init_common(buf, count, dtype, root, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -79,3 +86,26 @@ int mca_coll_ucc_ibcast(void *buf, size_t count, struct ompi_datatype_t *dtype, return ucc_module->previous_ibcast(buf, count, dtype, root, comm, request, ucc_module->previous_ibcast_module); } + +int mca_coll_ucc_bcast_init(void *buf, size_t count, struct ompi_datatype_t *dtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PERSISTENT(coll_req); + UCC_VERBOSE(3, "bcast_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_bcast_init_common(buf, count, dtype, root, + true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback bcast_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_bcast_init(buf, count, dtype, root, comm, info, request, + ucc_module->previous_bcast_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_common.h b/ompi/mca/coll/ucc/coll_ucc_common.h index 9d9163aa46d..09bd9359a1e 100644 --- a/ompi/mca/coll/ucc/coll_ucc_common.h +++ b/ompi/mca/coll/ucc/coll_ucc_common.h @@ -1,5 +1,6 @@ /** Copyright (c) 2021 Mellanox Technologies. All rights reserved. + Copyright (c) 2025 Fujitsu Limited. All rights reserved. $COPYRIGHT$ Additional copyrights may follow $HEADER$ @@ -42,6 +43,25 @@ _coll_req->super.req_type = OMPI_REQUEST_COLL; \ } while(0) +#define COLL_UCC_GET_REQ_PERSISTENT(_coll_req) \ + do { \ + opal_free_list_item_t *item; \ + item = opal_free_list_wait(&mca_coll_ucc_component.requests); \ + if (OPAL_UNLIKELY(NULL == item)) { \ + UCC_ERROR("failed to get mca_coll_ucc_req from free_list"); \ + goto fallback; \ + } \ + _coll_req = (mca_coll_ucc_req_t *) item; \ + OMPI_REQUEST_INIT(&_coll_req->super, true); \ + _coll_req->super.req_complete_cb = NULL; \ + _coll_req->super.req_complete_cb_data = NULL; \ + _coll_req->super.req_status.MPI_ERROR = MPI_SUCCESS; \ + _coll_req->super.req_free = mca_coll_ucc_req_free; \ + _coll_req->super.req_start = mca_coll_ucc_req_start; \ + _coll_req->super.req_type = OMPI_REQUEST_COLL; \ + _coll_req->ucc_req = NULL; \ + } while (0) + #define COLL_UCC_REQ_INIT(_coll_req, _req, _coll, _module) do{ \ if (_coll_req) { \ _coll.mask |= UCC_COLL_ARGS_FIELD_CB; \ @@ -76,5 +96,6 @@ static inline ucc_status_t coll_ucc_req_wait(ucc_coll_req_h req) int mca_coll_ucc_req_free(struct ompi_request_t **ompi_req); void mca_coll_ucc_completion(void *data, ucc_status_t status); +int mca_coll_ucc_req_start(size_t count, struct ompi_request_t **requests); #endif diff --git a/ompi/mca/coll/ucc/coll_ucc_component.c b/ompi/mca/coll/ucc/coll_ucc_component.c index 2f065c8404e..4fde1e0a999 100644 --- a/ompi/mca/coll/ucc/coll_ucc_component.c +++ b/ompi/mca/coll/ucc/coll_ucc_component.c @@ -3,6 +3,7 @@ * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. * Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -143,6 +144,58 @@ static ucc_coll_type_t mca_coll_ucc_str_to_type(const char *str) return UCC_COLL_TYPE_LAST; } +/* is a persistent collective */ +static inline int mca_coll_ucc_init_cts_is_persistent(const char *cp, char *bp, size_t bz) +{ + size_t len = strlen(cp), len_suffix = sizeof("_init") - 1; + + assert((bz > 0) && (bp != 0)); + /* check if it is a persistent collective */ + if (len > len_suffix) { + size_t blen = len - len_suffix; + const char *cp_suffix = &cp[blen]; + + if (0 == strcmp(cp_suffix, "_init")) { + int wc = snprintf(bp, bz, "%*.*s", (int)blen, (int)blen, cp); + if ((wc < 0) || ((size_t)wc >= bz)) { + return -1 /* XXX internal error */; + } + return 1 /* true */; + } + } + return 0 /* false */; +} + +/* is an alias (special) name */ +static inline int mca_coll_ucc_init_cts_is_alias(const char *cp, bool disable, + mca_coll_ucc_component_t *cm) +{ + if (0 == strcmp(cp, "colls_b")) { /* all blocking colls */ + if (disable) { + cm->cts_requested &= ~COLL_UCC_CTS; + } else { + cm->cts_requested |= COLL_UCC_CTS; + } + return 1 /* true */; + } else if ((0 == strcmp(cp, "colls_i")) || (0 == strcmp(cp, "colls_nb"))) { + /* all non-blocking colls */ + if (disable) { + cm->nb_cts_requested &= ~COLL_UCC_CTS; + } else { + cm->nb_cts_requested |= COLL_UCC_CTS; + } + return 1 /* true */; + } else if (0 == strcmp(cp, "colls_p")) { /* all persistent colls */ + if (disable) { + cm->ps_cts_requested &= ~COLL_UCC_CTS; + } else { + cm->ps_cts_requested |= COLL_UCC_CTS; + } + return 1 /* true */; + } + return 0 /* false */; +} + static void mca_coll_ucc_init_default_cts(void) { mca_coll_ucc_component_t *cm = &mca_coll_ucc_component; @@ -157,11 +210,22 @@ static void mca_coll_ucc_init_default_cts(void) n_cts = opal_argv_count(cts); cm->cts_requested = disable ? COLL_UCC_CTS : 0; cm->nb_cts_requested = disable ? COLL_UCC_CTS : 0; + cm->ps_cts_requested = disable ? COLL_UCC_CTS : 0; for (i = 0; i < n_cts; i++) { + char l_str[64]; /* XXX sizeof("reduce_scatter_block") */ + size_t l_stz = sizeof(l_str); + + if (0 < mca_coll_ucc_init_cts_is_alias(cts[i], disable, cm)) { + continue; + } if (('i' == cts[i][0]) || ('I' == cts[i][0])) { /* non blocking collective setting */ str = cts[i] + 1; ct = &cm->nb_cts_requested; + } else if (0 < mca_coll_ucc_init_cts_is_persistent(cts[i], l_str, l_stz)) { + /* persistent collective setting */ + str = l_str; + ct = &cm->ps_cts_requested; } else { str = cts[i]; ct = &cm->cts_requested; diff --git a/ompi/mca/coll/ucc/coll_ucc_gather.c b/ompi/mca/coll/ucc/coll_ucc_gather.c index ba91b40b189..ad03d654b4c 100644 --- a/ompi/mca/coll/ucc/coll_ucc_gather.c +++ b/ompi/mca/coll/ucc/coll_ucc_gather.c @@ -2,6 +2,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -10,17 +11,18 @@ #include "coll_ucc_common.h" -static inline -ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, - void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, - int root, mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_gather_init_common(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, + int root, bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); int comm_rank = ompi_comm_rank(ucc_module->comm); int comm_size = ompi_comm_size(ucc_module->comm); + uint64_t flags = 0; if (comm_rank == root) { if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) || @@ -53,9 +55,12 @@ ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct om } } + flags = (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_GATHER, .root = root, .src.info = { @@ -72,10 +77,6 @@ ucc_status_t mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct om }, }; - if (is_inplace) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -91,9 +92,9 @@ int mca_coll_ucc_gather(const void *sbuf, size_t scount, struct ompi_datatype_t ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc gather"); - COLL_UCC_CHECK(mca_coll_ucc_gather_init(sbuf, scount, sdtype, rbuf, rcount, - rdtype, root, ucc_module, - &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_gather_init_common(sbuf, scount, sdtype, rbuf, rcount, + rdtype, root, false, ucc_module, + &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -116,9 +117,9 @@ int mca_coll_ucc_igather(const void *sbuf, size_t scount, struct ompi_datatype_t UCC_VERBOSE(3, "running ucc igather"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_gather_init(sbuf, scount, sdtype, rbuf, rcount, - rdtype, root, ucc_module, - &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_gather_init_common(sbuf, scount, sdtype, rbuf, rcount, + rdtype, root, false, ucc_module, + &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -131,3 +132,28 @@ int mca_coll_ucc_igather(const void *sbuf, size_t scount, struct ompi_datatype_t rdtype, root, comm, request, ucc_module->previous_igather_module); } + +int mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PERSISTENT(coll_req); + UCC_VERBOSE(3, "gather_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_gather_init_common(sbuf, scount, sdtype, rbuf, rcount, + rdtype, root, true, ucc_module, + &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback gather_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_gather_init(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, + info, request, ucc_module->previous_gather_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_gatherv.c b/ompi/mca/coll/ucc/coll_ucc_gatherv.c index 5a1da52356c..abbdde5a77b 100644 --- a/ompi/mca/coll/ucc/coll_ucc_gatherv.c +++ b/ompi/mca/coll/ucc/coll_ucc_gatherv.c @@ -2,6 +2,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -10,12 +11,13 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, - void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, - struct ompi_datatype_t *rdtype, int root, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_gatherv_init_common(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, + struct ompi_datatype_t *rdtype, int root, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == sbuf); @@ -45,7 +47,8 @@ static inline ucc_status_t mca_coll_ucc_gatherv_init(const void *sbuf, size_t sc flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) | (ompi_disp_array_is_64bit(disps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) | - (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0); + (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); ucc_coll_args_t coll = { .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, @@ -83,9 +86,9 @@ int mca_coll_ucc_gatherv(const void *sbuf, size_t scount, struct ompi_datatype_t ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc gatherv"); - COLL_UCC_CHECK(mca_coll_ucc_gatherv_init(sbuf, scount, sdtype, rbuf, rcounts, - disps, rdtype, root, ucc_module, - &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_gatherv_init_common(sbuf, scount, sdtype, rbuf, rcounts, + disps, rdtype, root, false, ucc_module, + &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -109,9 +112,9 @@ int mca_coll_ucc_igatherv(const void *sbuf, size_t scount, struct ompi_datatype_ UCC_VERBOSE(3, "running ucc igatherv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_gatherv_init(sbuf, scount, sdtype, rbuf, rcounts, - disps, rdtype, root, ucc_module, - &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_gatherv_init_common(sbuf, scount, sdtype, rbuf, rcounts, + disps, rdtype, root, false, ucc_module, + &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -124,3 +127,30 @@ int mca_coll_ucc_igatherv(const void *sbuf, size_t scount, struct ompi_datatype_ disps, rdtype, root, comm, request, ucc_module->previous_igatherv_module); } + +int mca_coll_ucc_gatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, + struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PERSISTENT(coll_req); + UCC_VERBOSE(3, "gatherv_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_gatherv_init_common(sbuf, scount, sdtype, rbuf, rcounts, + disps, rdtype, root, true, ucc_module, + &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback gatherv_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_gatherv_init(sbuf, scount, sdtype, rbuf, rcounts, disps, rdtype, + root, comm, info, request, + ucc_module->previous_gatherv_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_module.c b/ompi/mca/coll/ucc/coll_ucc_module.c index d297274e8c9..38901bc1403 100644 --- a/ompi/mca/coll/ucc/coll_ucc_module.c +++ b/ompi/mca/coll/ucc/coll_ucc_module.c @@ -4,6 +4,7 @@ * All Rights reserved. * Copyright (c) 2022-2025 NVIDIA Corporation. All rights reserved. * Copyright (c) 2024 Triad National Security, LLC. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -87,6 +88,34 @@ static void mca_coll_ucc_module_clear(mca_coll_ucc_module_t *ucc_module) ucc_module->previous_scatter_module = NULL; ucc_module->previous_iscatter = NULL; ucc_module->previous_iscatter_module = NULL; + ucc_module->previous_allreduce_init = NULL; + ucc_module->previous_allreduce_init_module = NULL; + ucc_module->previous_barrier_init = NULL; + ucc_module->previous_barrier_init_module = NULL; + ucc_module->previous_bcast_init = NULL; + ucc_module->previous_bcast_init_module = NULL; + ucc_module->previous_alltoall_init = NULL; + ucc_module->previous_alltoall_init_module = NULL; + ucc_module->previous_alltoallv_init = NULL; + ucc_module->previous_alltoallv_init_module = NULL; + ucc_module->previous_allgather_init = NULL; + ucc_module->previous_allgather_init_module = NULL; + ucc_module->previous_allgatherv_init = NULL; + ucc_module->previous_allgatherv_init_module = NULL; + ucc_module->previous_reduce_init = NULL; + ucc_module->previous_reduce_init_module = NULL; + ucc_module->previous_gather_init = NULL; + ucc_module->previous_gather_init_module = NULL; + ucc_module->previous_gatherv_init = NULL; + ucc_module->previous_gatherv_init_module = NULL; + ucc_module->previous_reduce_scatter_block_init = NULL; + ucc_module->previous_reduce_scatter_block_init_module = NULL; + ucc_module->previous_reduce_scatter_init = NULL; + ucc_module->previous_reduce_scatter_init_module = NULL; + ucc_module->previous_scatterv_init = NULL; + ucc_module->previous_scatterv_init_module = NULL; + ucc_module->previous_scatter_init = NULL; + ucc_module->previous_scatter_init_module = NULL; } static void mca_coll_ucc_module_construct(mca_coll_ucc_module_t *ucc_module) @@ -396,6 +425,12 @@ static inline ucc_ep_map_t get_rank_map(struct ompi_communicator_t *comm) MCA_COLL_INSTALL_API(__comm, i##__api, mca_coll_ucc_i##__api, &__ucc_module->super, "ucc"); \ (__ucc_module)->super.coll_i##__api = mca_coll_ucc_i##__api; \ } \ + if (mca_coll_ucc_component.ps_cts_requested & UCC_COLL_TYPE_##__COLL) \ + { \ + MCA_COLL_SAVE_API(__comm, __api##_init, (__ucc_module)->previous_##__api##_init, (__ucc_module)->previous_##__api##_init_module, "ucc"); \ + MCA_COLL_INSTALL_API(__comm, __api##_init, mca_coll_ucc_##__api##_init, &__ucc_module->super, "ucc"); \ + (__ucc_module)->super.coll_##__api##_init = mca_coll_ucc_##__api##_init; \ + } \ } \ } while (0) @@ -530,11 +565,32 @@ mca_coll_ucc_module_disable(mca_coll_base_module_t *module, UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce); UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce); UCC_UNINSTALL_COLL_API(comm, ucc_module, gather); + UCC_UNINSTALL_COLL_API(comm, ucc_module, igather); UCC_UNINSTALL_COLL_API(comm, ucc_module, gatherv); + UCC_UNINSTALL_COLL_API(comm, ucc_module, igatherv); UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter_block); + UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce_scatter_block); UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter); + UCC_UNINSTALL_COLL_API(comm, ucc_module, ireduce_scatter); UCC_UNINSTALL_COLL_API(comm, ucc_module, scatter); + UCC_UNINSTALL_COLL_API(comm, ucc_module, iscatter); UCC_UNINSTALL_COLL_API(comm, ucc_module, scatterv); + UCC_UNINSTALL_COLL_API(comm, ucc_module, iscatterv); + + UCC_UNINSTALL_COLL_API(comm, ucc_module, allreduce_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, barrier_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, bcast_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, alltoall_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, alltoallv_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, allgather_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, allgatherv_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, gather_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, gatherv_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter_block_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, reduce_scatter_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, scatter_init); + UCC_UNINSTALL_COLL_API(comm, ucc_module, scatterv_init); return OMPI_SUCCESS; } @@ -592,6 +648,19 @@ OBJ_CLASS_INSTANCE(mca_coll_ucc_req_t, ompi_request_t, int mca_coll_ucc_req_free(struct ompi_request_t **ompi_req) { + { + mca_coll_ucc_req_t *coll_req = (mca_coll_ucc_req_t *) ompi_req[0]; + if (true == coll_req->super.req_persistent) { + UCC_VERBOSE(5, "%s free %p", "_init", coll_req); + if (NULL != coll_req->ucc_req) { + ucc_status_t rc_ucc; + rc_ucc = ucc_collective_finalize(coll_req->ucc_req); + if (UCC_OK != rc_ucc) { + UCC_ERROR("ucc_collective_finalize failed: %s", ucc_status_string(rc_ucc)); + } + } + } + } opal_free_list_return (&mca_coll_ucc_component.requests, (opal_free_list_item_t *)(*ompi_req)); *ompi_req = MPI_REQUEST_NULL; @@ -602,6 +671,56 @@ int mca_coll_ucc_req_free(struct ompi_request_t **ompi_req) void mca_coll_ucc_completion(void *data, ucc_status_t status) { mca_coll_ucc_req_t *coll_req = (mca_coll_ucc_req_t*)data; - ucc_collective_finalize(coll_req->ucc_req); + if (false == coll_req->super.req_persistent) { + ucc_collective_finalize(coll_req->ucc_req); + } else { + UCC_VERBOSE(5, "%s done %p", "_init", coll_req); + assert(!REQUEST_COMPLETE(&coll_req->super)); + } ompi_request_complete(&coll_req->super, true); } + +/* req_start() : ompi_request_start_fn_t */ +int mca_coll_ucc_req_start(size_t count, struct ompi_request_t **requests) +{ + size_t ii; + int rc = OMPI_SUCCESS; + + for (ii = 0; ii < count; ++ii) { + mca_coll_ucc_req_t *coll_req = (mca_coll_ucc_req_t *) requests[ii]; + ucc_status_t rc_ucc; + + if ((NULL == coll_req) || (OMPI_REQUEST_COLL != coll_req->super.req_type)) { + continue; + } + if (true != coll_req->super.req_persistent) { + coll_req->super.req_status.MPI_ERROR = MPI_ERR_REQUEST; + if (OMPI_SUCCESS == rc) { + rc = OMPI_ERROR; + } + continue; + } + UCC_VERBOSE(5, "%s post %p", "_init", coll_req); + assert(REQUEST_COMPLETE(&coll_req->super)); + assert(OMPI_REQUEST_INACTIVE == coll_req->super.req_state); + + coll_req->super.req_status.MPI_TAG = MPI_ANY_TAG; + coll_req->super.req_status.MPI_ERROR = OMPI_SUCCESS; + coll_req->super.req_status._cancelled = 0; + coll_req->super.req_complete = REQUEST_PENDING; + coll_req->super.req_state = OMPI_REQUEST_ACTIVE; + + rc_ucc = ucc_collective_post(coll_req->ucc_req); + if (UCC_OK != rc_ucc) { + UCC_ERROR("ucc_collective_post failed: %s", ucc_status_string(rc_ucc)); + coll_req->super.req_complete = REQUEST_COMPLETED; + coll_req->super.req_status.MPI_ERROR = MPI_ERR_INTERN; + if (OMPI_SUCCESS == rc) { + rc = OMPI_ERROR; + } + continue; + } + } + + return rc; +} diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce.c b/ompi/mca/coll/ucc/coll_ucc_reduce.c index 97b5d424ccf..c76b16c8881 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce.c @@ -1,5 +1,6 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -8,15 +9,16 @@ #include "coll_ucc_common.h" -static inline ucc_status_t mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf, size_t count, - struct ompi_datatype_t *dtype, - struct ompi_op_t *op, int root, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t mca_coll_ucc_reduce_init_common(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, int root, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_dt; ucc_reduction_op_t ucc_op; + uint64_t flags = 0; ucc_dt = ompi_dtype_to_ucc_dtype(dtype); ucc_op = ompi_op_to_ucc_op(op); @@ -30,9 +32,13 @@ static inline ucc_status_t mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf op->o_name); goto fallback; } + + flags = ((MPI_IN_PLACE == sbuf) ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_REDUCE, .root = root, .src.info = { @@ -49,10 +55,7 @@ static inline ucc_status_t mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf }, .op = ucc_op, }; - if (MPI_IN_PLACE == sbuf) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } + COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -69,8 +72,8 @@ int mca_coll_ucc_reduce(const void *sbuf, void* rbuf, size_t count, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc reduce"); - COLL_UCC_CHECK(mca_coll_ucc_reduce_init(sbuf, rbuf, count, dtype, op, - root, ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_init_common(sbuf, rbuf, count, dtype, op, + root, false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -93,8 +96,8 @@ int mca_coll_ucc_ireduce(const void *sbuf, void* rbuf, size_t count, UCC_VERBOSE(3, "running ucc ireduce"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_init(sbuf, rbuf, count, dtype, op, root, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_init_common(sbuf, rbuf, count, dtype, op, root, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -106,3 +109,27 @@ int mca_coll_ucc_ireduce(const void *sbuf, void* rbuf, size_t count, return ucc_module->previous_ireduce(sbuf, rbuf, count, dtype, op, root, comm, request, ucc_module->previous_ireduce_module); } + +int mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf, size_t count, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PERSISTENT(coll_req); + UCC_VERBOSE(3, "reduce_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_reduce_init_common(sbuf, rbuf, count, dtype, op, root, + true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback reduce_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_reduce_init(sbuf, rbuf, count, dtype, op, root, comm, info, request, + ucc_module->previous_reduce_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c index dabc8f11d03..7ba6effb774 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -9,12 +10,12 @@ #include "coll_ucc_common.h" -static inline -ucc_status_t mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi_count_array_t rcounts, - struct ompi_datatype_t *dtype, - struct ompi_op_t *op, mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_reduce_scatter_init_common(const void *sbuf, void *rbuf, ompi_count_array_t rcounts, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_dt; ucc_reduction_op_t ucc_op; @@ -47,7 +48,8 @@ ucc_status_t mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi total_count += ompi_count_array_get(rcounts, i); } - flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0); + flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); ucc_coll_args_t coll = { .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, @@ -83,8 +85,8 @@ int mca_coll_ucc_reduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_t ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc reduce_scatter"); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_init(sbuf, rbuf, rcounts, dtype, - op, ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_init_common(sbuf, rbuf, rcounts, dtype, + op, false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -108,8 +110,8 @@ int mca_coll_ucc_ireduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_ UCC_VERBOSE(3, "running ucc ireduce_scatter"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_init(sbuf, rbuf, rcounts, dtype, - op, ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_init_common(sbuf, rbuf, rcounts, dtype, + op, false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -122,3 +124,28 @@ int mca_coll_ucc_ireduce_scatter(const void *sbuf, void *rbuf, ompi_count_array_ comm, request, ucc_module->previous_ireduce_scatter_module); } + +int mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi_count_array_t rcounts, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PERSISTENT(coll_req); + UCC_VERBOSE(3, "reduce_scatter_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_init_common(sbuf, rbuf, rcounts, dtype, + op, true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback reduce_scatter_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module + ->previous_reduce_scatter_init(sbuf, rbuf, rcounts, dtype, op, comm, info, request, + ucc_module->previous_reduce_scatter_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c index 781776e42ca..49deba9393e 100644 --- a/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c +++ b/ompi/mca/coll/ucc/coll_ucc_reduce_scatter_block.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -9,18 +10,19 @@ #include "coll_ucc_common.h" -static inline -ucc_status_t mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf, - size_t rcount, - struct ompi_datatype_t *dtype, - struct ompi_op_t *op, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_reduce_scatter_block_init_common(const void *sbuf, void *rbuf, + size_t rcount, + struct ompi_datatype_t *dtype, + struct ompi_op_t *op, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_dt; ucc_reduction_op_t ucc_op; int comm_size = ompi_comm_size(ucc_module->comm); + uint64_t flags = 0; if (MPI_IN_PLACE == sbuf) { /* TODO: UCC defines inplace differently: @@ -40,9 +42,12 @@ ucc_status_t mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf op->o_name); goto fallback; } + + flags = (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_REDUCE_SCATTER, .src.info = { .buffer = (void*)sbuf, @@ -58,6 +63,7 @@ ucc_status_t mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf }, .op = ucc_op, }; + COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -74,9 +80,9 @@ int mca_coll_ucc_reduce_scatter_block(const void *sbuf, void *rbuf, size_t rcoun ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc reduce scatter block"); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init(sbuf, rbuf, rcount, - dtype, op, ucc_module, - &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init_common(sbuf, rbuf, rcount, + dtype, op, false, ucc_module, + &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -100,9 +106,9 @@ int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, size_t rcou UCC_VERBOSE(3, "running ucc ireduce_scatter_block"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init(sbuf, rbuf, rcount, - dtype, op, ucc_module, - &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init_common(sbuf, rbuf, rcount, + dtype, op, false, ucc_module, + &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -115,3 +121,30 @@ int mca_coll_ucc_ireduce_scatter_block(const void *sbuf, void *rbuf, size_t rcou op, comm, request, ucc_module->previous_ireduce_scatter_block_module); } + +int mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf, size_t rcount, + struct ompi_datatype_t *dtype, struct ompi_op_t *op, + struct ompi_communicator_t *comm, + struct ompi_info_t *info, ompi_request_t **request, + mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PERSISTENT(coll_req); + UCC_VERBOSE(3, "reduce_scatter_block_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_reduce_scatter_block_init_common(sbuf, rbuf, rcount, + dtype, op, true, ucc_module, + &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback reduce_scatter_block_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module + ->previous_reduce_scatter_block_init(sbuf, rbuf, rcount, dtype, op, comm, info, request, + ucc_module->previous_reduce_scatter_block_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_scatter.c b/ompi/mca/coll/ucc/coll_ucc_scatter.c index 481365f22bd..4f4e60eaec3 100644 --- a/ompi/mca/coll/ucc/coll_ucc_scatter.c +++ b/ompi/mca/coll/ucc/coll_ucc_scatter.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -9,19 +10,20 @@ #include "coll_ucc_common.h" -static inline -ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, - struct ompi_datatype_t *sdtype, - void *rbuf, size_t rcount, - struct ompi_datatype_t *rdtype, int root, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_scatter_init_common(const void *sbuf, size_t scount, + struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, + struct ompi_datatype_t *rdtype, int root, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == rbuf); int comm_rank = ompi_comm_rank(ucc_module->comm); int comm_size = ompi_comm_size(ucc_module->comm); + uint64_t flags = 0; if (comm_rank == root) { if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(rdtype, rcount)) || @@ -54,9 +56,12 @@ ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, } } + flags = (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); + ucc_coll_args_t coll = { - .mask = 0, - .flags = 0, + .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, + .flags = flags, .coll_type = UCC_COLL_TYPE_SCATTER, .root = root, .src.info = { @@ -73,10 +78,6 @@ ucc_status_t mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, }, }; - if (is_inplace) { - coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; - coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE; - } COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); return UCC_OK; fallback: @@ -93,9 +94,9 @@ int mca_coll_ucc_scatter(const void *sbuf, size_t scount, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc scatter"); - COLL_UCC_CHECK(mca_coll_ucc_scatter_init(sbuf, scount, sdtype, rbuf, rcount, - rdtype, root, ucc_module, &req, - NULL)); + COLL_UCC_CHECK(mca_coll_ucc_scatter_init_common(sbuf, scount, sdtype, rbuf, rcount, + rdtype, root, false, ucc_module, &req, + NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -120,9 +121,9 @@ int mca_coll_ucc_iscatter(const void *sbuf, size_t scount, UCC_VERBOSE(3, "running ucc iscatter"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_scatter_init(sbuf, scount, sdtype, rbuf, rcount, - rdtype, root, ucc_module, &req, - coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_scatter_init_common(sbuf, scount, sdtype, rbuf, rcount, + rdtype, root, false, ucc_module, &req, + coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -135,3 +136,29 @@ int mca_coll_ucc_iscatter(const void *sbuf, size_t scount, rdtype, root, comm, request, ucc_module->previous_iscatter_module); } + +int mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PERSISTENT(coll_req); + UCC_VERBOSE(3, "scatter_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_scatter_init_common(sbuf, scount, sdtype, rbuf, rcount, + rdtype, root, true, ucc_module, &req, + coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback scatter_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_scatter_init(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm, + info, request, + ucc_module->previous_scatter_init_module); +} diff --git a/ompi/mca/coll/ucc/coll_ucc_scatterv.c b/ompi/mca/coll/ucc/coll_ucc_scatterv.c index 36d4086a113..c1a611afd53 100644 --- a/ompi/mca/coll/ucc/coll_ucc_scatterv.c +++ b/ompi/mca/coll/ucc/coll_ucc_scatterv.c @@ -1,6 +1,7 @@ /** * Copyright (c) 2021 Mellanox Technologies. All rights reserved. * Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + * Copyright (c) 2025 Fujitsu Limited. All rights reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -9,14 +10,14 @@ #include "coll_ucc_common.h" -static inline -ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t scounts, - ompi_disp_array_t disps, struct ompi_datatype_t *sdtype, - void *rbuf, size_t rcount, - struct ompi_datatype_t *rdtype, int root, - mca_coll_ucc_module_t *ucc_module, - ucc_coll_req_h *req, - mca_coll_ucc_req_t *coll_req) +static inline ucc_status_t +mca_coll_ucc_scatterv_init_common(const void *sbuf, ompi_count_array_t scounts, + ompi_disp_array_t disps, struct ompi_datatype_t *sdtype, + void *rbuf, size_t rcount, + struct ompi_datatype_t *rdtype, int root, + bool persistent, mca_coll_ucc_module_t *ucc_module, + ucc_coll_req_h *req, + mca_coll_ucc_req_t *coll_req) { ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; bool is_inplace = (MPI_IN_PLACE == rbuf); @@ -46,7 +47,8 @@ ucc_status_t mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t sco flags = (ompi_count_array_is_64bit(scounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) | (ompi_disp_array_is_64bit(disps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) | - (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0); + (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) | + (persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0); ucc_coll_args_t coll = { .mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0, @@ -85,9 +87,9 @@ int mca_coll_ucc_scatterv(const void *sbuf, ompi_count_array_t scounts, ucc_coll_req_h req; UCC_VERBOSE(3, "running ucc scatterv"); - COLL_UCC_CHECK(mca_coll_ucc_scatterv_init(sbuf, scounts, disps, sdtype, - rbuf, rcount, rdtype, root, - ucc_module, &req, NULL)); + COLL_UCC_CHECK(mca_coll_ucc_scatterv_init_common(sbuf, scounts, disps, sdtype, + rbuf, rcount, rdtype, root, + false, ucc_module, &req, NULL)); COLL_UCC_POST_AND_CHECK(req); COLL_UCC_CHECK(coll_ucc_req_wait(req)); return OMPI_SUCCESS; @@ -112,9 +114,9 @@ int mca_coll_ucc_iscatterv(const void *sbuf, ompi_count_array_t scounts, UCC_VERBOSE(3, "running ucc iscatterv"); COLL_UCC_GET_REQ(coll_req); - COLL_UCC_CHECK(mca_coll_ucc_scatterv_init(sbuf, scounts, disps, sdtype, - rbuf, rcount, rdtype, root, - ucc_module, &req, coll_req)); + COLL_UCC_CHECK(mca_coll_ucc_scatterv_init_common(sbuf, scounts, disps, sdtype, + rbuf, rcount, rdtype, root, + false, ucc_module, &req, coll_req)); COLL_UCC_POST_AND_CHECK(req); *request = &coll_req->super; return OMPI_SUCCESS; @@ -127,3 +129,30 @@ int mca_coll_ucc_iscatterv(const void *sbuf, ompi_count_array_t scounts, rcount, rdtype, root, comm, request, ucc_module->previous_iscatterv_module); } + +int mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t scounts, + ompi_disp_array_t disps, struct ompi_datatype_t *sdtype, void *rbuf, + size_t rcount, struct ompi_datatype_t *rdtype, int root, + struct ompi_communicator_t *comm, struct ompi_info_t *info, + ompi_request_t **request, mca_coll_base_module_t *module) +{ + mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; + ucc_coll_req_h req; + mca_coll_ucc_req_t *coll_req = NULL; + + COLL_UCC_GET_REQ_PERSISTENT(coll_req); + UCC_VERBOSE(3, "scatterv_init init %p", coll_req); + COLL_UCC_CHECK(mca_coll_ucc_scatterv_init_common(sbuf, scounts, disps, sdtype, + rbuf, rcount, rdtype, root, + true, ucc_module, &req, coll_req)); + *request = &coll_req->super; + return OMPI_SUCCESS; +fallback: + UCC_VERBOSE(3, "running fallback scatterv_init"); + if (coll_req) { + mca_coll_ucc_req_free((ompi_request_t **) &coll_req); + } + return ucc_module->previous_scatterv_init(sbuf, scounts, disps, sdtype, rbuf, rcount, rdtype, + root, comm, info, request, + ucc_module->previous_scatterv_init_module); +}