Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions ompi/mca/coll/ucc/coll_ucc.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/**
Copyright (c) 2021 Mellanox Technologies. All rights reserved.
Copyright (c) 2022 NVIDIA Corporation. All rights reserved.
Copyright (c) 2025 Fujitsu Limited. All rights reserved.
$COPYRIGHT$

Additional copyrights may follow
Expand Down Expand Up @@ -61,6 +62,7 @@ struct mca_coll_ucc_component_t {
ucc_lib_attr_t ucc_lib_attr;
ucc_coll_type_t cts_requested;
ucc_coll_type_t nb_cts_requested;
ucc_coll_type_t ps_cts_requested;
ucc_context_h ucc_context;
opal_free_list_t requests;
};
Expand Down Expand Up @@ -132,6 +134,34 @@ struct mca_coll_ucc_module_t {
mca_coll_base_module_t* previous_scatter_module;
mca_coll_base_module_iscatter_fn_t previous_iscatter;
mca_coll_base_module_t* previous_iscatter_module;
mca_coll_base_module_allreduce_init_fn_t previous_allreduce_init;
mca_coll_base_module_t* previous_allreduce_init_module;
mca_coll_base_module_reduce_init_fn_t previous_reduce_init;
mca_coll_base_module_t* previous_reduce_init_module;
mca_coll_base_module_barrier_init_fn_t previous_barrier_init;
mca_coll_base_module_t* previous_barrier_init_module;
mca_coll_base_module_bcast_init_fn_t previous_bcast_init;
mca_coll_base_module_t* previous_bcast_init_module;
mca_coll_base_module_alltoall_init_fn_t previous_alltoall_init;
mca_coll_base_module_t* previous_alltoall_init_module;
mca_coll_base_module_alltoallv_init_fn_t previous_alltoallv_init;
mca_coll_base_module_t* previous_alltoallv_init_module;
mca_coll_base_module_allgather_init_fn_t previous_allgather_init;
mca_coll_base_module_t* previous_allgather_init_module;
mca_coll_base_module_allgatherv_init_fn_t previous_allgatherv_init;
mca_coll_base_module_t* previous_allgatherv_init_module;
mca_coll_base_module_gather_init_fn_t previous_gather_init;
mca_coll_base_module_t* previous_gather_init_module;
mca_coll_base_module_gatherv_init_fn_t previous_gatherv_init;
mca_coll_base_module_t* previous_gatherv_init_module;
mca_coll_base_module_reduce_scatter_block_init_fn_t previous_reduce_scatter_block_init;
mca_coll_base_module_t* previous_reduce_scatter_block_init_module;
mca_coll_base_module_reduce_scatter_init_fn_t previous_reduce_scatter_init;
mca_coll_base_module_t* previous_reduce_scatter_init_module;
mca_coll_base_module_scatterv_init_fn_t previous_scatterv_init;
mca_coll_base_module_t* previous_scatterv_init_module;
mca_coll_base_module_scatter_init_fn_t previous_scatter_init;
mca_coll_base_module_t* previous_scatter_init_module;
};
/* Alias so the module struct can be referred to without the `struct` keyword. */
typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t;
/* NOTE(review): presumably declares the OPAL object-class descriptor for this
 * type — confirm against the OBJ_CLASS_DECLARATION macro definition. */
OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t);
Expand Down Expand Up @@ -305,5 +335,78 @@ int mca_coll_ucc_iscatter(const void *sbuf, size_t scount,
ompi_request_t** request,
mca_coll_base_module_t *module);

/* Declares the UCC component's allreduce_init entry point. It takes the
 * allreduce arguments (send/receive buffers, count, datatype, op,
 * communicator) plus an info object, a request pointer-to-pointer and the
 * coll module.
 * NOTE(review): the definition is not visible here; presumably it sets up a
 * persistent request returned through `request` — confirm. */
int mca_coll_ucc_allreduce_init(const void *sbuf, void *rbuf, size_t count,
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
struct ompi_communicator_t *comm, struct ompi_info_t *info,
ompi_request_t **request, mca_coll_base_module_t *module);

/* Declares the UCC component's reduce_init entry point. It takes the reduce
 * arguments (send/receive buffers, count, datatype, op, root, communicator)
 * plus an info object, a request pointer-to-pointer and the coll module.
 * NOTE(review): the definition is not visible here; presumably it sets up a
 * persistent request returned through `request` — confirm. */
int mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf, size_t count,
struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root,
struct ompi_communicator_t *comm, struct ompi_info_t *info,
ompi_request_t **request, mca_coll_base_module_t *module);

/* Declares the UCC component's barrier_init entry point. It takes only the
 * communicator, an info object, a request pointer-to-pointer and the coll
 * module (a barrier carries no data buffers).
 * NOTE(review): the definition is not visible here; presumably it sets up a
 * persistent request returned through `request` — confirm. */
int mca_coll_ucc_barrier_init(struct ompi_communicator_t *comm, struct ompi_info_t *info,
ompi_request_t **request, mca_coll_base_module_t *module);

/* Declares the UCC component's bcast_init entry point. It takes the broadcast
 * arguments (a single buffer, count, datatype, root, communicator) plus an
 * info object, a request pointer-to-pointer and the coll module.
 * NOTE(review): the definition is not visible here; presumably it sets up a
 * persistent request returned through `request` — confirm. */
int mca_coll_ucc_bcast_init(void *buff, size_t count, struct ompi_datatype_t *datatype, int root,
struct ompi_communicator_t *comm, struct ompi_info_t *info,
ompi_request_t **request, mca_coll_base_module_t *module);

/* Declares the UCC component's alltoall_init entry point. It takes the send
 * triple (sbuf, scount, sdtype), the receive triple (rbuf, rcount, rdtype)
 * and the communicator, plus an info object, a request pointer-to-pointer
 * and the coll module.
 * NOTE(review): the definition is not visible here; presumably it sets up a
 * persistent request returned through `request` — confirm. */
int mca_coll_ucc_alltoall_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
struct ompi_communicator_t *comm, struct ompi_info_t *info,
ompi_request_t **request, mca_coll_base_module_t *module);

/* Declares the UCC component's alltoallv_init entry point. Counts and
 * displacements are passed as ompi_count_array_t / ompi_disp_array_t for both
 * the send side (scounts, sdisps) and the receive side (rcounts, rdisps),
 * followed by the communicator, an info object, a request pointer-to-pointer
 * and the coll module.
 * NOTE(review): the definition is not visible here; presumably it sets up a
 * persistent request returned through `request` — confirm. */
int mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_count_array_t scounts,
ompi_disp_array_t sdisps, struct ompi_datatype_t *sdtype,
void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm,
struct ompi_info_t *info, ompi_request_t **request,
mca_coll_base_module_t *module);

/* Declares the UCC component's allgather_init entry point (defined in
 * coll_ucc_allgather.c). The definition initializes a UCC allgather with the
 * persistent flag requested and stores the resulting request in *request; if
 * that path is abandoned it delegates to the previously saved
 * allgather_init function of the module. */
int mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
struct ompi_communicator_t *comm, struct ompi_info_t *info,
ompi_request_t **request, mca_coll_base_module_t *module);

/* Declares the UCC component's allgatherv_init entry point (defined in
 * coll_ucc_allgatherv.c, where the displacement parameter is named rdisps).
 * The definition initializes a UCC allgatherv with the persistent flag
 * requested and stores the resulting request in *request; if that path is
 * abandoned it delegates to the previously saved allgatherv_init function of
 * the module. */
int mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps,
struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm,
struct ompi_info_t *info, ompi_request_t **request,
mca_coll_base_module_t *module);

/* Declares the UCC component's gather_init entry point. It takes the send
 * triple (sbuf, scount, sdtype), the receive triple (rbuf, rcount, rdtype),
 * the root rank and the communicator, plus an info object, a request
 * pointer-to-pointer and the coll module.
 * NOTE(review): the definition is not visible here; presumably it sets up a
 * persistent request returned through `request` — confirm. */
int mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root,
struct ompi_communicator_t *comm, struct ompi_info_t *info,
ompi_request_t **request, mca_coll_base_module_t *module);

/* Declares the UCC component's gatherv_init entry point. The receive side is
 * described by per-rank arrays (rcounts, disps) rather than a single count;
 * the remaining arguments are the send triple, root, communicator, an info
 * object, a request pointer-to-pointer and the coll module.
 * NOTE(review): the definition is not visible here; presumably it sets up a
 * persistent request returned through `request` — confirm. */
int mca_coll_ucc_gatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps,
struct ompi_datatype_t *rdtype, int root,
struct ompi_communicator_t *comm, struct ompi_info_t *info,
ompi_request_t **request, mca_coll_base_module_t *module);

/* Declares the UCC component's reduce_scatter_block_init entry point. It
 * takes send/receive buffers, a single receive count, datatype, op and
 * communicator, plus an info object, a request pointer-to-pointer and the
 * coll module.
 * NOTE(review): the definition is not visible here; presumably it sets up a
 * persistent request returned through `request` — confirm. */
int mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf, size_t rcount,
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
struct ompi_communicator_t *comm,
struct ompi_info_t *info, ompi_request_t **request,
mca_coll_base_module_t *module);

/* Declares the UCC component's reduce_scatter_init entry point. Unlike the
 * _block variant, receive counts are passed as an ompi_count_array_t; the
 * other arguments are send/receive buffers, datatype, op, communicator, an
 * info object, a request pointer-to-pointer and the coll module.
 * NOTE(review): the definition is not visible here; presumably it sets up a
 * persistent request returned through `request` — confirm. */
int mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi_count_array_t rcounts,
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
struct ompi_communicator_t *comm, struct ompi_info_t *info,
ompi_request_t **request, mca_coll_base_module_t *module);

/* Declares the UCC component's scatterv_init entry point. The send side is
 * described by per-rank arrays (scounts, disps); the receive side by a single
 * (rbuf, rcount, rdtype) triple. The remaining arguments are the root,
 * communicator, an info object, a request pointer-to-pointer and the coll
 * module.
 * NOTE(review): the definition is not visible here; presumably it sets up a
 * persistent request returned through `request` — confirm. */
int mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t scounts,
ompi_disp_array_t disps, struct ompi_datatype_t *sdtype, void *rbuf,
size_t rcount, struct ompi_datatype_t *rdtype, int root,
struct ompi_communicator_t *comm, struct ompi_info_t *info,
ompi_request_t **request, mca_coll_base_module_t *module);

/* Declares the UCC component's scatter_init entry point. It takes the send
 * triple (sbuf, scount, sdtype), the receive triple (rbuf, rcount, rdtype),
 * the root rank and the communicator, plus an info object, a request
 * pointer-to-pointer and the coll module.
 * NOTE(review): the definition is not visible here; presumably it sets up a
 * persistent request returned through `request` — confirm. */
int mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root,
struct ompi_communicator_t *comm, struct ompi_info_t *info,
ompi_request_t **request, mca_coll_base_module_t *module);

END_C_DECLS
#endif
62 changes: 45 additions & 17 deletions ompi/mca/coll/ucc/coll_ucc_allgather.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

/**
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
* Copyright (c) 2025 Fujitsu Limited. All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
Expand All @@ -9,15 +10,17 @@

#include "coll_ucc_common.h"

static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
mca_coll_ucc_module_t *ucc_module,
ucc_coll_req_h *req,
mca_coll_ucc_req_t *coll_req)
static inline ucc_status_t
mca_coll_ucc_allgather_init_common(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
void* rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
bool persistent, mca_coll_ucc_module_t *ucc_module,
ucc_coll_req_h *req,
mca_coll_ucc_req_t *coll_req)
{
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
bool is_inplace = (MPI_IN_PLACE == sbuf);
int comm_size = ompi_comm_size(ucc_module->comm);
uint64_t flags = 0;

if (!(is_inplace || ompi_datatype_is_contiguous_memory_layout(sdtype, scount)) ||
!ompi_datatype_is_contiguous_memory_layout(rdtype, rcount * comm_size)) {
Expand All @@ -37,9 +40,12 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t
goto fallback;
}

flags = (is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) |
(persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0);

ucc_coll_args_t coll = {
.mask = 0,
.flags = 0,
.mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
.flags = flags,
.coll_type = UCC_COLL_TYPE_ALLGATHER,
.src.info = {
.buffer = (void*)sbuf,
Expand All @@ -55,10 +61,6 @@ static inline ucc_status_t mca_coll_ucc_allgather_init(const void *sbuf, size_t
}
};

if (is_inplace) {
coll.mask = UCC_COLL_ARGS_FIELD_FLAGS;
coll.flags = UCC_COLL_ARGS_FLAG_IN_PLACE;
}
COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module);
return UCC_OK;
fallback:
Expand All @@ -74,9 +76,9 @@ int mca_coll_ucc_allgather(const void *sbuf, size_t scount, struct ompi_datatype
ucc_coll_req_h req;

UCC_VERBOSE(3, "running ucc allgather");
COLL_UCC_CHECK(mca_coll_ucc_allgather_init(sbuf, scount, sdtype,
rbuf, rcount, rdtype,
ucc_module, &req, NULL));
COLL_UCC_CHECK(mca_coll_ucc_allgather_init_common(sbuf, scount, sdtype,
rbuf, rcount, rdtype,
false, ucc_module, &req, NULL));
COLL_UCC_POST_AND_CHECK(req);
COLL_UCC_CHECK(coll_ucc_req_wait(req));
return OMPI_SUCCESS;
Expand All @@ -98,9 +100,9 @@ int mca_coll_ucc_iallgather(const void *sbuf, size_t scount, struct ompi_datatyp

UCC_VERBOSE(3, "running ucc iallgather");
COLL_UCC_GET_REQ(coll_req);
COLL_UCC_CHECK(mca_coll_ucc_allgather_init(sbuf, scount, sdtype,
rbuf, rcount, rdtype,
ucc_module, &req, coll_req));
COLL_UCC_CHECK(mca_coll_ucc_allgather_init_common(sbuf, scount, sdtype,
rbuf, rcount, rdtype,
false, ucc_module, &req, coll_req));
COLL_UCC_POST_AND_CHECK(req);
*request = &coll_req->super;
return OMPI_SUCCESS;
Expand All @@ -112,3 +114,29 @@ int mca_coll_ucc_iallgather(const void *sbuf, size_t scount, struct ompi_datatyp
return ucc_module->previous_iallgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, request, ucc_module->previous_iallgather_module);
}

/**
 * allgather_init entry point: set up (but do not post) a UCC allgather.
 *
 * A request object is obtained with COLL_UCC_GET_REQ_PERSISTENT and handed to
 * mca_coll_ucc_allgather_init_common() with its `persistent` argument set to
 * true, so UCC_COLL_ARGS_FLAG_PERSISTENT is requested. Nothing is posted or
 * waited on in this function.
 *
 * @param sbuf     send buffer, forwarded unchanged to the common initializer
 * @param scount   send count
 * @param sdtype   send datatype
 * @param rbuf     receive buffer
 * @param rcount   receive count
 * @param rdtype   receive datatype
 * @param comm     communicator; only used when delegating to the fallback
 * @param info     info object; only used when delegating to the fallback
 * @param request  on success receives the address of coll_req->super
 * @param module   coll module, treated as a mca_coll_ucc_module_t
 *
 * @return OMPI_SUCCESS on the UCC path; otherwise the return value of the
 *         previously saved allgather_init function.
 *
 * NOTE(review): there is no explicit goto in this function, so the `fallback`
 * label is presumably reached from inside COLL_UCC_GET_REQ_PERSISTENT and/or
 * COLL_UCC_CHECK — confirm against their definitions.
 */
int mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
                                void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype,
                                struct ompi_communicator_t *comm, struct ompi_info_t *info,
                                ompi_request_t **request, mca_coll_base_module_t *module)
{
    mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module;
    ucc_coll_req_h req;
    /* Starts as NULL so the fallback path can tell whether a request exists. */
    mca_coll_ucc_req_t *coll_req = NULL;

    COLL_UCC_GET_REQ_PERSISTENT(coll_req);
    /* The %p conversion requires a void pointer argument, hence the cast. */
    UCC_VERBOSE(3, "allgather_init init %p", (void *) coll_req);
    COLL_UCC_CHECK(mca_coll_ucc_allgather_init_common(sbuf, scount, sdtype,
                                                      rbuf, rcount, rdtype,
                                                      true, ucc_module, &req, coll_req));
    *request = &coll_req->super;
    return OMPI_SUCCESS;
fallback:
    UCC_VERBOSE(3, "running fallback allgather_init");
    /* Release the request obtained above before delegating. */
    if (coll_req) {
        mca_coll_ucc_req_free((ompi_request_t **) &coll_req);
    }
    return ucc_module->previous_allgather_init(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm,
                                               info, request,
                                               ucc_module->previous_allgather_init_module);
}
58 changes: 44 additions & 14 deletions ompi/mca/coll/ucc/coll_ucc_allgatherv.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

/**
* Copyright (c) 2021 Mellanox Technologies. All rights reserved.
* Copyright (c) 2025 Fujitsu Limited. All rights reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
Expand All @@ -9,13 +10,14 @@

#include "coll_ucc_common.h"

static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount,
struct ompi_datatype_t *sdtype,
void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
struct ompi_datatype_t *rdtype,
mca_coll_ucc_module_t *ucc_module,
ucc_coll_req_h *req,
mca_coll_ucc_req_t *coll_req)
static inline ucc_status_t
mca_coll_ucc_allgatherv_init_common(const void *sbuf, size_t scount,
struct ompi_datatype_t *sdtype,
void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
struct ompi_datatype_t *rdtype,
bool persistent, mca_coll_ucc_module_t *ucc_module,
ucc_coll_req_h *req,
mca_coll_ucc_req_t *coll_req)
{
ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8;
bool is_inplace = (MPI_IN_PLACE == sbuf);
Expand All @@ -36,7 +38,8 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t

flags = (ompi_count_array_is_64bit(rcounts) ? UCC_COLL_ARGS_FLAG_COUNT_64BIT : 0) |
(ompi_disp_array_is_64bit(rdisps) ? UCC_COLL_ARGS_FLAG_DISPLACEMENTS_64BIT : 0) |
(is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0);
(is_inplace ? UCC_COLL_ARGS_FLAG_IN_PLACE : 0) |
(persistent ? UCC_COLL_ARGS_FLAG_PERSISTENT : 0);

ucc_coll_args_t coll = {
.mask = flags ? UCC_COLL_ARGS_FIELD_FLAGS : 0,
Expand Down Expand Up @@ -75,9 +78,9 @@ int mca_coll_ucc_allgatherv(const void *sbuf, size_t scount,

UCC_VERBOSE(3, "running ucc allgatherv");

COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init(sbuf, scount, sdtype,
rbuf, rcounts, rdisps, rdtype,
ucc_module, &req, NULL));
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init_common(sbuf, scount, sdtype,
rbuf, rcounts, rdisps, rdtype,
false, ucc_module, &req, NULL));
COLL_UCC_POST_AND_CHECK(req);
COLL_UCC_CHECK(coll_ucc_req_wait(req));
return OMPI_SUCCESS;
Expand All @@ -102,9 +105,9 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, size_t scount,

UCC_VERBOSE(3, "running ucc iallgatherv");
COLL_UCC_GET_REQ(coll_req);
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init(sbuf, scount, sdtype,
rbuf, rcounts, rdisps, rdtype,
ucc_module, &req, coll_req));
COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init_common(sbuf, scount, sdtype,
rbuf, rcounts, rdisps, rdtype,
false, ucc_module, &req, coll_req));
COLL_UCC_POST_AND_CHECK(req);
*request = &coll_req->super;
return OMPI_SUCCESS;
Expand All @@ -117,3 +120,30 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, size_t scount,
rbuf, rcounts, rdisps, rdtype,
comm, request, ucc_module->previous_iallgatherv_module);
}

/**
 * allgatherv_init entry point: set up (but do not post) a UCC allgatherv.
 *
 * A request object is obtained with COLL_UCC_GET_REQ_PERSISTENT and handed to
 * mca_coll_ucc_allgatherv_init_common() with its `persistent` argument set to
 * true, so UCC_COLL_ARGS_FLAG_PERSISTENT is requested. Nothing is posted or
 * waited on in this function.
 *
 * @param sbuf     send buffer, forwarded unchanged to the common initializer
 * @param scount   send count
 * @param sdtype   send datatype
 * @param rbuf     receive buffer
 * @param rcounts  receive counts array
 * @param rdisps   receive displacements array
 * @param rdtype   receive datatype
 * @param comm     communicator; only used when delegating to the fallback
 * @param info     info object; only used when delegating to the fallback
 * @param request  on success receives the address of coll_req->super
 * @param module   coll module, treated as a mca_coll_ucc_module_t
 *
 * @return OMPI_SUCCESS on the UCC path; otherwise the return value of the
 *         previously saved allgatherv_init function.
 *
 * NOTE(review): there is no explicit goto in this function, so the `fallback`
 * label is presumably reached from inside COLL_UCC_GET_REQ_PERSISTENT and/or
 * COLL_UCC_CHECK — confirm against their definitions.
 */
int mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype,
                                 void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps,
                                 struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm,
                                 struct ompi_info_t *info, ompi_request_t **request,
                                 mca_coll_base_module_t *module)
{
    mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module;
    ucc_coll_req_h req;
    /* Starts as NULL so the fallback path can tell whether a request exists. */
    mca_coll_ucc_req_t *coll_req = NULL;

    COLL_UCC_GET_REQ_PERSISTENT(coll_req);
    /* The %p conversion requires a void pointer argument, hence the cast. */
    UCC_VERBOSE(3, "allgatherv_init init %p", (void *) coll_req);
    COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init_common(sbuf, scount, sdtype,
                                                       rbuf, rcounts, rdisps, rdtype,
                                                       true, ucc_module, &req, coll_req));
    *request = &coll_req->super;
    return OMPI_SUCCESS;
fallback:
    UCC_VERBOSE(3, "running fallback allgatherv_init");
    /* Release the request obtained above before delegating. */
    if (coll_req) {
        mca_coll_ucc_req_free((ompi_request_t **) &coll_req);
    }
    return ucc_module->previous_allgatherv_init(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype,
                                                comm, info, request,
                                                ucc_module->previous_allgatherv_init_module);
}
Loading