@@ -291,7 +291,8 @@ static int mca_coll_ucc_init_ctx() {
291
291
}
292
292
ucc_context_config_release (ctx_config );
293
293
294
- copy_fn .attr_communicator_copy_fn = (MPI_Comm_internal_copy_attr_function * ) MPI_COMM_NULL_COPY_FN ;
294
+ copy_fn .attr_communicator_copy_fn = (MPI_Comm_internal_copy_attr_function )
295
+ MPI_COMM_NULL_COPY_FN ;
295
296
del_fn .attr_communicator_delete_fn = ucc_comm_attr_del_fn ;
296
297
if (OMPI_SUCCESS != ompi_attr_create_keyval (COMM_ATTR , copy_fn , del_fn ,
297
298
& ucc_comm_attr_keyval , NULL ,0 , NULL )) {
@@ -320,6 +321,52 @@ static int mca_coll_ucc_init_ctx() {
320
321
return OMPI_ERROR ;
321
322
}
322
323
324
+ uint64_t rank_map_cb (uint64_t ep , void * cb_ctx )
325
+ {
326
+ struct ompi_communicator_t * comm = cb_ctx ;
327
+
328
+ return ((ompi_process_name_t * )& ompi_comm_peer_lookup (comm , ep )-> super .
329
+ proc_name )-> vpid ;
330
+ }
331
+
332
+ static inline ucc_ep_map_t get_rank_map (struct ompi_communicator_t * comm )
333
+ {
334
+ ucc_ep_map_t map ;
335
+ int64_t r1 , r2 , stride , i ;
336
+ int is_strided ;
337
+
338
+ map .ep_num = ompi_comm_size (comm );
339
+ if (comm == & ompi_mpi_comm_world .comm ) {
340
+ map .type = UCC_EP_MAP_FULL ;
341
+ return map ;
342
+ }
343
+
344
+ /* try to detect strided pattern */
345
+ is_strided = 1 ;
346
+ r1 = rank_map_cb (0 , comm );
347
+ r2 = rank_map_cb (1 , comm );
348
+ stride = r2 - r1 ;
349
+ for (i = 2 ; i < map .ep_num ; i ++ ) {
350
+ r1 = r2 ;
351
+ r2 = rank_map_cb (i , comm );
352
+ if (r2 - r1 != stride ) {
353
+ is_strided = 0 ;
354
+ break ;
355
+ }
356
+ }
357
+
358
+ if (is_strided ) {
359
+ map .type = UCC_EP_MAP_STRIDED ;
360
+ map .strided .start = r1 ;
361
+ map .strided .stride = stride ;
362
+ } else {
363
+ map .type = UCC_EP_MAP_CB ;
364
+ map .cb .cb = rank_map_cb ;
365
+ map .cb .cb_ctx = (void * )comm ;
366
+ }
367
+
368
+ return map ;
369
+ }
323
370
/*
324
371
* Initialize module on the communicator
325
372
*/
@@ -331,16 +378,15 @@ static int mca_coll_ucc_module_enable(mca_coll_base_module_t *module,
331
378
ucc_status_t status ;
332
379
int rc ;
333
380
ucc_team_params_t team_params = {
334
- .mask = UCC_TEAM_PARAM_FIELD_EP |
335
- UCC_TEAM_PARAM_FIELD_EP_RANGE |
336
- UCC_TEAM_PARAM_FIELD_OOB ,
337
- .oob = {
338
- .allgather = oob_allgather ,
339
- .req_test = oob_allgather_test ,
340
- .req_free = oob_allgather_free ,
341
- .coll_info = (void * )comm ,
342
- .n_oob_eps = ompi_comm_size (comm ),
343
- .oob_ep = ompi_comm_rank (comm )
381
+ .mask = UCC_TEAM_PARAM_FIELD_EP_MAP |
382
+ UCC_TEAM_PARAM_FIELD_EP |
383
+ UCC_TEAM_PARAM_FIELD_EP_RANGE ,
384
+ .ep_map = {
385
+ .type = (comm == & ompi_mpi_comm_world .comm ) ?
386
+ UCC_EP_MAP_FULL : UCC_EP_MAP_CB ,
387
+ .ep_num = ompi_comm_size (comm ),
388
+ .cb .cb = rank_map_cb ,
389
+ .cb .cb_ctx = (void * )comm
344
390
},
345
391
.ep = ompi_comm_rank (comm ),
346
392
.ep_range = UCC_COLLECTIVE_EP_RANGE_CONTIG
0 commit comments