Skip to content

Commit ef17dcb

Browse files
author
Valentin Petrov
committed
coll/ucc: use ep_map for team creation
Signed-off-by: Valentin Petrov <[email protected]>
1 parent 5782c0a commit ef17dcb

File tree

1 file changed

+57
-11
lines changed

1 file changed

+57
-11
lines changed

ompi/mca/coll/ucc/coll_ucc_module.c

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,8 @@ static int mca_coll_ucc_init_ctx() {
291291
}
292292
ucc_context_config_release(ctx_config);
293293

294-
copy_fn.attr_communicator_copy_fn = (MPI_Comm_internal_copy_attr_function*) MPI_COMM_NULL_COPY_FN;
294+
copy_fn.attr_communicator_copy_fn = (MPI_Comm_internal_copy_attr_function)
295+
MPI_COMM_NULL_COPY_FN;
295296
del_fn.attr_communicator_delete_fn = ucc_comm_attr_del_fn;
296297
if (OMPI_SUCCESS != ompi_attr_create_keyval(COMM_ATTR, copy_fn, del_fn,
297298
&ucc_comm_attr_keyval, NULL ,0, NULL)) {
@@ -320,6 +321,52 @@ static int mca_coll_ucc_init_ctx() {
320321
return OMPI_ERROR;
321322
}
322323

324+
uint64_t rank_map_cb(uint64_t ep, void *cb_ctx)
325+
{
326+
struct ompi_communicator_t *comm = cb_ctx;
327+
328+
return ((ompi_process_name_t*)&ompi_comm_peer_lookup(comm, ep)->super.
329+
proc_name)->vpid;
330+
}
331+
332+
static inline ucc_ep_map_t get_rank_map(struct ompi_communicator_t *comm)
333+
{
334+
ucc_ep_map_t map;
335+
int64_t r1, r2, stride, i;
336+
int is_strided;
337+
338+
map.ep_num = ompi_comm_size(comm);
339+
if (comm == &ompi_mpi_comm_world.comm) {
340+
map.type = UCC_EP_MAP_FULL;
341+
return map;
342+
}
343+
344+
/* try to detect strided pattern */
345+
is_strided = 1;
346+
r1 = rank_map_cb(0, comm);
347+
r2 = rank_map_cb(1, comm);
348+
stride = r2 - r1;
349+
for (i = 2; i < map.ep_num; i++) {
350+
r1 = r2;
351+
r2 = rank_map_cb(i, comm);
352+
if (r2 - r1 != stride) {
353+
is_strided = 0;
354+
break;
355+
}
356+
}
357+
358+
if (is_strided) {
359+
map.type = UCC_EP_MAP_STRIDED;
360+
map.strided.start = r1;
361+
map.strided.stride = stride;
362+
} else {
363+
map.type = UCC_EP_MAP_CB;
364+
map.cb.cb = rank_map_cb;
365+
map.cb.cb_ctx = (void*)comm;
366+
}
367+
368+
return map;
369+
}
323370
/*
324371
* Initialize module on the communicator
325372
*/
@@ -331,16 +378,15 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module,
331378
ucc_status_t status;
332379
int rc;
333380
ucc_team_params_t team_params = {
334-
.mask = UCC_TEAM_PARAM_FIELD_EP |
335-
UCC_TEAM_PARAM_FIELD_EP_RANGE |
336-
UCC_TEAM_PARAM_FIELD_OOB,
337-
.oob = {
338-
.allgather = oob_allgather,
339-
.req_test = oob_allgather_test,
340-
.req_free = oob_allgather_free,
341-
.coll_info = (void*)comm,
342-
.n_oob_eps = ompi_comm_size(comm),
343-
.oob_ep = ompi_comm_rank(comm)
381+
.mask = UCC_TEAM_PARAM_FIELD_EP_MAP |
382+
UCC_TEAM_PARAM_FIELD_EP |
383+
UCC_TEAM_PARAM_FIELD_EP_RANGE,
384+
.ep_map = {
385+
.type = (comm == &ompi_mpi_comm_world.comm) ?
386+
UCC_EP_MAP_FULL : UCC_EP_MAP_CB,
387+
.ep_num = ompi_comm_size(comm),
388+
.cb.cb = rank_map_cb,
389+
.cb.cb_ctx = (void*)comm
344390
},
345391
.ep = ompi_comm_rank(comm),
346392
.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG

0 commit comments

Comments
 (0)