@@ -291,7 +291,8 @@ static int mca_coll_ucc_init_ctx() {
291291 }
292292 ucc_context_config_release (ctx_config );
293293
294- copy_fn .attr_communicator_copy_fn = (MPI_Comm_internal_copy_attr_function * ) MPI_COMM_NULL_COPY_FN ;
294+ copy_fn .attr_communicator_copy_fn = (MPI_Comm_internal_copy_attr_function )
295+ MPI_COMM_NULL_COPY_FN ;
295296 del_fn .attr_communicator_delete_fn = ucc_comm_attr_del_fn ;
296297 if (OMPI_SUCCESS != ompi_attr_create_keyval (COMM_ATTR , copy_fn , del_fn ,
297298 & ucc_comm_attr_keyval , NULL ,0 , NULL )) {
@@ -320,6 +321,52 @@ static int mca_coll_ucc_init_ctx() {
320321 return OMPI_ERROR ;
321322}
322323
324+ uint64_t rank_map_cb (uint64_t ep , void * cb_ctx )
325+ {
326+ struct ompi_communicator_t * comm = cb_ctx ;
327+
328+ return ((ompi_process_name_t * )& ompi_comm_peer_lookup (comm , ep )-> super .
329+ proc_name )-> vpid ;
330+ }
331+
332+ static inline ucc_ep_map_t get_rank_map (struct ompi_communicator_t * comm )
333+ {
334+ ucc_ep_map_t map ;
335+ int64_t r1 , r2 , stride , i ;
336+ int is_strided ;
337+
338+ map .ep_num = ompi_comm_size (comm );
339+ if (comm == & ompi_mpi_comm_world .comm ) {
340+ map .type = UCC_EP_MAP_FULL ;
341+ return map ;
342+ }
343+
344+ /* try to detect strided pattern */
345+ is_strided = 1 ;
346+ r1 = rank_map_cb (0 , comm );
347+ r2 = rank_map_cb (1 , comm );
348+ stride = r2 - r1 ;
349+ for (i = 2 ; i < map .ep_num ; i ++ ) {
350+ r1 = r2 ;
351+ r2 = rank_map_cb (i , comm );
352+ if (r2 - r1 != stride ) {
353+ is_strided = 0 ;
354+ break ;
355+ }
356+ }
357+
358+ if (is_strided ) {
359+ map .type = UCC_EP_MAP_STRIDED ;
360+ map .strided .start = r1 ;
361+ map .strided .stride = stride ;
362+ } else {
363+ map .type = UCC_EP_MAP_CB ;
364+ map .cb .cb = rank_map_cb ;
365+ map .cb .cb_ctx = (void * )comm ;
366+ }
367+
368+ return map ;
369+ }
323370/*
324371 * Initialize module on the communicator
325372 */
@@ -331,19 +378,20 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module,
331378 ucc_status_t status ;
332379 int rc ;
333380 ucc_team_params_t team_params = {
334- .mask = UCC_TEAM_PARAM_FIELD_EP |
335- UCC_TEAM_PARAM_FIELD_EP_RANGE |
336- UCC_TEAM_PARAM_FIELD_OOB ,
337- . oob = {
338- . allgather = oob_allgather ,
339- .req_test = oob_allgather_test ,
340- . req_free = oob_allgather_free ,
341- .coll_info = ( void * ) comm ,
342- .n_oob_eps = ompi_comm_size ( comm ) ,
343- .oob_ep = ompi_comm_rank ( comm )
381+ .mask = UCC_TEAM_PARAM_FIELD_EP_MAP |
382+ UCC_TEAM_PARAM_FIELD_EP |
383+ UCC_TEAM_PARAM_FIELD_EP_RANGE |
384+ UCC_TEAM_PARAM_FIELD_ID ,
385+ . ep_map = {
386+ .type = ( comm == & ompi_mpi_comm_world . comm ) ?
387+ UCC_EP_MAP_FULL : UCC_EP_MAP_CB ,
388+ .ep_num = ompi_comm_size ( comm ) ,
389+ .cb . cb = rank_map_cb ,
390+ .cb . cb_ctx = ( void * ) comm
344391 },
345392 .ep = ompi_comm_rank (comm ),
346- .ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG
393+ .ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG ,
394+ .id = comm -> c_contextid
347395 };
348396 UCC_VERBOSE (2 ,"creating ucc_team for comm %p, comm_id %d, comm_size %d" ,
349397 (void * )comm ,comm -> c_contextid ,ompi_comm_size (comm ));
0 commit comments