From 2c9f58368a14400071d9064a7320431a2ff9ba67 Mon Sep 17 00:00:00 2001 From: George Bosilca Date: Sun, 22 Jun 2025 16:20:55 -0700 Subject: [PATCH] Allow UCC to be used with sessions And other instances where OMPI CIDs are not global. In this case, OMPI maintains a translation table for each communicator, but this table is not exposed to other software layers (such as UCC). As a result UCC must be coerced to create a unique ID for the team by itself. Signed-off-by: George Bosilca --- ompi/communicator/comm.c | 8 +++++-- ompi/mca/coll/ucc/coll_ucc_module.c | 33 +++++++++++++++++------------ 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/ompi/communicator/comm.c b/ompi/communicator/comm.c index e7cdc68de63..af82a3e9d8c 100644 --- a/ompi/communicator/comm.c +++ b/ompi/communicator/comm.c @@ -2528,7 +2528,7 @@ int ompi_comm_determine_first_auto ( ompi_communicator_t* intercomm ) /********************************************************************************/ int ompi_comm_dump ( ompi_communicator_t *comm ) { - opal_output(0, "Dumping information for comm_cid %s\n", ompi_comm_print_cid (comm)); + opal_output(0, "Dumping information for comm_cid %s : %d\n", ompi_comm_print_cid (comm), ompi_comm_get_local_cid(comm)); opal_output(0," f2c index:%d cube_dim: %d\n", comm->c_f_to_c_index, comm->c_cube_dim); opal_output(0," Local group: size = %d my_rank = %d\n", @@ -2539,13 +2539,17 @@ int ompi_comm_dump ( ompi_communicator_t *comm ) /* Display flags */ if ( OMPI_COMM_IS_INTER(comm) ) opal_output(0," inter-comm,"); + else + opal_output(0," intra-comm,"); if ( OMPI_COMM_IS_CART(comm)) opal_output(0," topo-cart"); else if ( OMPI_COMM_IS_GRAPH(comm)) opal_output(0," topo-graph"); else if ( OMPI_COMM_IS_DIST_GRAPH(comm)) opal_output(0," topo-dist-graph"); - opal_output(0,"\n"); + else + opal_output(0, " no topo"); + opal_output(0,"\n"); if (OMPI_COMM_IS_INTER(comm)) { opal_output(0," Remote group size:%d\n", comm->c_remote_group->grp_proc_count); diff --git 
a/ompi/mca/coll/ucc/coll_ucc_module.c b/ompi/mca/coll/ucc/coll_ucc_module.c index c1b9fcf4f79..d297274e8c9 100644 --- a/ompi/mca/coll/ucc/coll_ucc_module.c +++ b/ompi/mca/coll/ucc/coll_ucc_module.c @@ -13,6 +13,7 @@ #include "ompi_config.h" #include "coll_ucc.h" +#include "coll_ucc_common.h" #include "coll_ucc_dtypes.h" #include "ompi/mca/coll/base/coll_tags.h" #include "ompi/mca/pml/pml.h" @@ -219,7 +220,8 @@ static ucc_status_t oob_allgather(void *sbuf, void *rbuf, size_t msglen, } -static int mca_coll_ucc_init_ctx() { +static int mca_coll_ucc_init_ctx(ompi_communicator_t* comm) +{ mca_coll_ucc_component_t *cm = &mca_coll_ucc_component; char str_buf[256]; ompi_attribute_fn_ptr_union_t del_fn; @@ -270,9 +272,9 @@ static int mca_coll_ucc_init_ctx() { ctx_params.oob.allgather = oob_allgather; ctx_params.oob.req_test = oob_allgather_test; ctx_params.oob.req_free = oob_allgather_free; - ctx_params.oob.coll_info = (void*)MPI_COMM_WORLD; - ctx_params.oob.n_oob_eps = ompi_comm_size(&ompi_mpi_comm_world.comm); - ctx_params.oob.oob_ep = ompi_comm_rank(&ompi_mpi_comm_world.comm); + ctx_params.oob.coll_info = (void*)comm; + ctx_params.oob.n_oob_eps = ompi_comm_size(comm); + ctx_params.oob.oob_ep = ompi_comm_rank(comm); if (UCC_OK != ucc_context_config_read(cm->ucc_lib, NULL, &ctx_config)) { UCC_ERROR("UCC context config read failed"); goto cleanup_lib; @@ -329,7 +331,7 @@ static int mca_coll_ucc_init_ctx() { return OMPI_ERROR; } -uint64_t rank_map_cb(uint64_t ep, void *cb_ctx) +static uint64_t rank_map_cb(uint64_t ep, void *cb_ctx) { struct ompi_communicator_t *comm = cb_ctx; @@ -433,8 +435,7 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module, ucc_team_params_t team_params = { .mask = UCC_TEAM_PARAM_FIELD_EP_MAP | UCC_TEAM_PARAM_FIELD_EP | - UCC_TEAM_PARAM_FIELD_EP_RANGE | - UCC_TEAM_PARAM_FIELD_ID, + UCC_TEAM_PARAM_FIELD_EP_RANGE, .ep_map = { .type = (comm == &ompi_mpi_comm_world.comm) ? 
UCC_EP_MAP_FULL : UCC_EP_MAP_CB, @@ -443,12 +444,18 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module, .cb.cb_ctx = (void*)comm }, .ep = ompi_comm_rank(comm), - .ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG, - .id = ompi_comm_get_local_cid(comm) + .ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG }; - UCC_VERBOSE(2, "creating ucc_team for comm %p, comm_id %llu, comm_size %d", - (void*)comm, (long long unsigned)team_params.id, - ompi_comm_size(comm)); + if (OMPI_COMM_IS_GLOBAL_INDEX(comm)) { + team_params.mask |= UCC_TEAM_PARAM_FIELD_ID; + team_params.id = ompi_comm_get_local_cid(comm); + UCC_VERBOSE(2, "creating ucc_team for comm %p, comm_id %llu, comm_size %d", + (void*)comm, (long long unsigned)team_params.id, + ompi_comm_size(comm)); + } else { + UCC_VERBOSE(2, "creating ucc_team for comm %p, comm_id not provided, comm_size %d", + (void*)comm, ompi_comm_size(comm)); + } if (UCC_OK != ucc_team_create_post(&cm->ucc_context, 1, &team_params, &ucc_module->ucc_team)) { @@ -555,7 +562,7 @@ mca_coll_ucc_comm_query(struct ompi_communicator_t *comm, int *priority) } if (!cm->libucc_initialized) { - if (OMPI_SUCCESS != mca_coll_ucc_init_ctx()) { + if (OMPI_SUCCESS != mca_coll_ucc_init_ctx(comm)) { cm->ucc_enable = 0; return NULL; }