-
Notifications
You must be signed in to change notification settings - Fork 928
COLL/UCC: add persistent collective calls for UCC #13374
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
2dec57d
9305110
4171fdb
1b4bf91
a8b244e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| /** | ||
| Copyright (c) 2021 Mellanox Technologies. All rights reserved. | ||
| Copyright (c) 2022 NVIDIA Corporation. All rights reserved. | ||
| Copyright (c) 2025 Fujitsu Limited. All rights reserved. | ||
| $COPYRIGHT$ | ||
|
|
||
| Additional copyrights may follow | ||
|
|
@@ -61,6 +62,7 @@ struct mca_coll_ucc_component_t { | |
| ucc_lib_attr_t ucc_lib_attr; | ||
| ucc_coll_type_t cts_requested; | ||
| ucc_coll_type_t nb_cts_requested; | ||
| ucc_coll_type_t pc_cts_requested; | ||
| ucc_context_h ucc_context; | ||
| opal_free_list_t requests; | ||
| }; | ||
|
|
@@ -132,6 +134,34 @@ struct mca_coll_ucc_module_t { | |
| mca_coll_base_module_t* previous_scatter_module; | ||
| mca_coll_base_module_iscatter_fn_t previous_iscatter; | ||
| mca_coll_base_module_t* previous_iscatter_module; | ||
| mca_coll_base_module_allreduce_init_fn_t previous_allreduce_init; | ||
|
||
| mca_coll_base_module_t *previous_allreduce_init_module; | ||
| mca_coll_base_module_reduce_init_fn_t previous_reduce_init; | ||
| mca_coll_base_module_t *previous_reduce_init_module; | ||
| mca_coll_base_module_barrier_init_fn_t previous_barrier_init; | ||
| mca_coll_base_module_t *previous_barrier_init_module; | ||
| mca_coll_base_module_bcast_init_fn_t previous_bcast_init; | ||
| mca_coll_base_module_t *previous_bcast_init_module; | ||
| mca_coll_base_module_alltoall_init_fn_t previous_alltoall_init; | ||
| mca_coll_base_module_t *previous_alltoall_init_module; | ||
| mca_coll_base_module_alltoallv_init_fn_t previous_alltoallv_init; | ||
| mca_coll_base_module_t *previous_alltoallv_init_module; | ||
| mca_coll_base_module_allgather_init_fn_t previous_allgather_init; | ||
| mca_coll_base_module_t *previous_allgather_init_module; | ||
| mca_coll_base_module_allgatherv_init_fn_t previous_allgatherv_init; | ||
| mca_coll_base_module_t *previous_allgatherv_init_module; | ||
| mca_coll_base_module_gather_init_fn_t previous_gather_init; | ||
| mca_coll_base_module_t *previous_gather_init_module; | ||
| mca_coll_base_module_gatherv_init_fn_t previous_gatherv_init; | ||
| mca_coll_base_module_t *previous_gatherv_init_module; | ||
| mca_coll_base_module_reduce_scatter_block_init_fn_t previous_reduce_scatter_block_init; | ||
| mca_coll_base_module_t *previous_reduce_scatter_block_init_module; | ||
| mca_coll_base_module_reduce_scatter_init_fn_t previous_reduce_scatter_init; | ||
| mca_coll_base_module_t *previous_reduce_scatter_init_module; | ||
| mca_coll_base_module_scatterv_init_fn_t previous_scatterv_init; | ||
| mca_coll_base_module_t *previous_scatterv_init_module; | ||
| mca_coll_base_module_scatter_init_fn_t previous_scatter_init; | ||
| mca_coll_base_module_t *previous_scatter_init_module; | ||
| }; | ||
| typedef struct mca_coll_ucc_module_t mca_coll_ucc_module_t; | ||
| OBJ_CLASS_DECLARATION(mca_coll_ucc_module_t); | ||
|
|
@@ -305,5 +335,78 @@ int mca_coll_ucc_iscatter(const void *sbuf, size_t scount, | |
| ompi_request_t** request, | ||
| mca_coll_base_module_t *module); | ||
|
|
||
| int mca_coll_ucc_allreduce_init(const void *sbuf, void *rbuf, size_t count, | ||
| struct ompi_datatype_t *dtype, struct ompi_op_t *op, | ||
| struct ompi_communicator_t *comm, struct ompi_info_t *info, | ||
| ompi_request_t **request, mca_coll_base_module_t *module); | ||
|
|
||
| int mca_coll_ucc_reduce_init(const void *sbuf, void *rbuf, size_t count, | ||
| struct ompi_datatype_t *dtype, struct ompi_op_t *op, int root, | ||
| struct ompi_communicator_t *comm, struct ompi_info_t *info, | ||
| ompi_request_t **request, mca_coll_base_module_t *module); | ||
|
|
||
| int mca_coll_ucc_barrier_init(struct ompi_communicator_t *comm, struct ompi_info_t *info, | ||
| ompi_request_t **request, mca_coll_base_module_t *module); | ||
|
|
||
| int mca_coll_ucc_bcast_init(void *buff, size_t count, struct ompi_datatype_t *datatype, int root, | ||
| struct ompi_communicator_t *comm, struct ompi_info_t *info, | ||
| ompi_request_t **request, mca_coll_base_module_t *module); | ||
|
|
||
| int mca_coll_ucc_alltoall_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, | ||
| void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, | ||
| struct ompi_communicator_t *comm, struct ompi_info_t *info, | ||
| ompi_request_t **request, mca_coll_base_module_t *module); | ||
|
|
||
| int mca_coll_ucc_alltoallv_init(const void *sbuf, ompi_count_array_t scounts, | ||
| ompi_disp_array_t sdisps, struct ompi_datatype_t *sdtype, | ||
| void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, | ||
| struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, | ||
| struct ompi_info_t *info, ompi_request_t **request, | ||
| mca_coll_base_module_t *module); | ||
|
|
||
| int mca_coll_ucc_allgather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, | ||
| void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, | ||
| struct ompi_communicator_t *comm, struct ompi_info_t *info, | ||
| ompi_request_t **request, mca_coll_base_module_t *module); | ||
|
|
||
| int mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, | ||
| void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, | ||
| struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, | ||
| struct ompi_info_t *info, ompi_request_t **request, | ||
| mca_coll_base_module_t *module); | ||
|
|
||
| int mca_coll_ucc_gather_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, | ||
| void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, | ||
| struct ompi_communicator_t *comm, struct ompi_info_t *info, | ||
| ompi_request_t **request, mca_coll_base_module_t *module); | ||
|
|
||
| int mca_coll_ucc_gatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, | ||
| void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t disps, | ||
| struct ompi_datatype_t *rdtype, int root, | ||
| struct ompi_communicator_t *comm, struct ompi_info_t *info, | ||
| ompi_request_t **request, mca_coll_base_module_t *module); | ||
|
|
||
| int mca_coll_ucc_reduce_scatter_block_init(const void *sbuf, void *rbuf, size_t rcount, | ||
| struct ompi_datatype_t *dtype, struct ompi_op_t *op, | ||
| struct ompi_communicator_t *comm, | ||
| struct ompi_info_t *info, ompi_request_t **request, | ||
| mca_coll_base_module_t *module); | ||
|
|
||
| int mca_coll_ucc_reduce_scatter_init(const void *sbuf, void *rbuf, ompi_count_array_t rcounts, | ||
| struct ompi_datatype_t *dtype, struct ompi_op_t *op, | ||
| struct ompi_communicator_t *comm, struct ompi_info_t *info, | ||
| ompi_request_t **request, mca_coll_base_module_t *module); | ||
|
|
||
| int mca_coll_ucc_scatterv_init(const void *sbuf, ompi_count_array_t scounts, | ||
| ompi_disp_array_t disps, struct ompi_datatype_t *sdtype, void *rbuf, | ||
| size_t rcount, struct ompi_datatype_t *rdtype, int root, | ||
| struct ompi_communicator_t *comm, struct ompi_info_t *info, | ||
| ompi_request_t **request, mca_coll_base_module_t *module); | ||
|
|
||
| int mca_coll_ucc_scatter_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, | ||
| void *rbuf, size_t rcount, struct ompi_datatype_t *rdtype, int root, | ||
| struct ompi_communicator_t *comm, struct ompi_info_t *info, | ||
| ompi_request_t **request, mca_coll_base_module_t *module); | ||
|
|
||
| END_C_DECLS | ||
| #endif | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
|
|
||
| /** | ||
| * Copyright (c) 2021 Mellanox Technologies. All rights reserved. | ||
| * Copyright (c) 2025 Fujitsu Limited. All rights reserved. | ||
| * $COPYRIGHT$ | ||
| * | ||
| * Additional copyrights may follow | ||
|
|
@@ -9,13 +10,12 @@ | |
|
|
||
| #include "coll_ucc_common.h" | ||
|
|
||
| static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, | ||
| struct ompi_datatype_t *sdtype, | ||
| void* rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, | ||
| struct ompi_datatype_t *rdtype, | ||
| mca_coll_ucc_module_t *ucc_module, | ||
| ucc_coll_req_h *req, | ||
| mca_coll_ucc_req_t *coll_req) | ||
| static inline ucc_status_t | ||
| mca_coll_ucc_allgatherv_iniz(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, | ||
| void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, | ||
| struct ompi_datatype_t *rdtype, bool persistent, | ||
| mca_coll_ucc_module_t *ucc_module, ucc_coll_req_h *req, | ||
| mca_coll_ucc_req_t *coll_req) | ||
| { | ||
| ucc_datatype_t ucc_sdt = UCC_DT_INT8, ucc_rdt = UCC_DT_INT8; | ||
| bool is_inplace = (MPI_IN_PLACE == sbuf); | ||
|
|
@@ -57,6 +57,10 @@ static inline ucc_status_t mca_coll_ucc_allgatherv_init(const void *sbuf, size_t | |
| } | ||
| }; | ||
|
|
||
| if (true == persistent) { | ||
|
||
| coll.mask |= UCC_COLL_ARGS_FIELD_FLAGS; | ||
| coll.flags |= UCC_COLL_ARGS_FLAG_PERSISTENT; | ||
| } | ||
| COLL_UCC_REQ_INIT(coll_req, req, coll, ucc_module); | ||
| return UCC_OK; | ||
| fallback: | ||
|
|
@@ -75,9 +79,8 @@ int mca_coll_ucc_allgatherv(const void *sbuf, size_t scount, | |
|
|
||
| UCC_VERBOSE(3, "running ucc allgatherv"); | ||
|
|
||
| COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init(sbuf, scount, sdtype, | ||
| rbuf, rcounts, rdisps, rdtype, | ||
| ucc_module, &req, NULL)); | ||
| COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, | ||
| false, ucc_module, &req, NULL)); | ||
| COLL_UCC_POST_AND_CHECK(req); | ||
| COLL_UCC_CHECK(coll_ucc_req_wait(req)); | ||
| return OMPI_SUCCESS; | ||
|
|
@@ -102,9 +105,8 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, size_t scount, | |
|
|
||
| UCC_VERBOSE(3, "running ucc iallgatherv"); | ||
| COLL_UCC_GET_REQ(coll_req); | ||
| COLL_UCC_CHECK(mca_coll_ucc_allgatherv_init(sbuf, scount, sdtype, | ||
| rbuf, rcounts, rdisps, rdtype, | ||
| ucc_module, &req, coll_req)); | ||
| COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, | ||
| false, ucc_module, &req, coll_req)); | ||
| COLL_UCC_POST_AND_CHECK(req); | ||
| *request = &coll_req->super; | ||
| return OMPI_SUCCESS; | ||
|
|
@@ -117,3 +119,29 @@ int mca_coll_ucc_iallgatherv(const void *sbuf, size_t scount, | |
| rbuf, rcounts, rdisps, rdtype, | ||
| comm, request, ucc_module->previous_iallgatherv_module); | ||
| } | ||
|
|
||
| int mca_coll_ucc_allgatherv_init(const void *sbuf, size_t scount, struct ompi_datatype_t *sdtype, | ||
| void *rbuf, ompi_count_array_t rcounts, ompi_disp_array_t rdisps, | ||
| struct ompi_datatype_t *rdtype, struct ompi_communicator_t *comm, | ||
| struct ompi_info_t *info, ompi_request_t **request, | ||
| mca_coll_base_module_t *module) | ||
| { | ||
| mca_coll_ucc_module_t *ucc_module = (mca_coll_ucc_module_t *) module; | ||
| ucc_coll_req_h req; | ||
| mca_coll_ucc_req_t *coll_req = NULL; | ||
|
|
||
| COLL_UCC_GET_REQ_PC(coll_req); | ||
| UCC_VERBOSE(3, "allgatherv_init init %p", coll_req); | ||
| COLL_UCC_CHECK(mca_coll_ucc_allgatherv_iniz(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, | ||
| true, ucc_module, &req, coll_req)); | ||
| *request = &coll_req->super; | ||
| return OMPI_SUCCESS; | ||
| fallback: | ||
| UCC_VERBOSE(3, "running fallback allgatherv_init"); | ||
| if (coll_req) { | ||
| mca_coll_ucc_req_free((ompi_request_t **) &coll_req); | ||
| } | ||
| return ucc_module->previous_allgatherv_init(sbuf, scount, sdtype, rbuf, rcounts, rdisps, rdtype, | ||
| comm, info, request, | ||
| ucc_module->previous_allgatherv_init_module); | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does pc stand for persistent collective? maybe ps_cts_requested is better
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that
ps_cts_requestedis better. I changed pc_cts_requested tops_cts_requested.