@@ -487,9 +487,9 @@ static int ompi_osc_ucx_shared_query_peer(ompi_osc_ucx_module_t *module, int pee
487
487
if (UCS_OK != ucp_rkey_ptr (rkey , module -> addrs [peer ], & addr_p )) {
488
488
return OMPI_ERR_NOT_AVAILABLE ;
489
489
}
490
- * size = module -> sizes [peer ];
491
- * ((void * * ) baseptr ) = ( void * ) module -> shmem_addrs [ peer ] ;
492
- * disp_unit = module -> disp_units [peer ];
490
+ * size = module -> same_size ? module -> size : module -> sizes [peer ];
491
+ * ((void * * ) baseptr ) = addr_p ;
492
+ * disp_unit = ( module -> disp_unit < 0 ) ? module -> disp_units [peer ] : module -> disp_unit ;
493
493
494
494
return OMPI_SUCCESS ;
495
495
}
@@ -554,8 +554,9 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
554
554
int flavor , int * model ) {
555
555
ompi_osc_ucx_module_t * module = NULL ;
556
556
char * name = NULL ;
557
- long values [2 ];
557
+ long values [4 ];
558
558
int ret = OMPI_SUCCESS ;
559
+ int val_count = 0 ;
559
560
int i , comm_size = ompi_comm_size (comm );
560
561
bool env_initialized = false;
561
562
void * state_base = NULL ;
@@ -679,42 +680,70 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, pt
679
680
module -> acc_single_intrinsic = check_config_value_bool ("acc_single_intrinsic" , info );
680
681
module -> skip_sync_check = false;
681
682
682
- /**
683
- * TODO: we need to collect the shared memory information from all processes
684
- * on the same node. This includes the size and base address, which needs
685
- * to be passed to ucp_rkey_ptr().
686
- */
687
- module -> shmem_info = NULL ;
688
-
689
683
/* share everyone's displacement units. Only do an allgather if
690
684
strictly necessary, since it requires O(p) state. */
691
685
values [0 ] = disp_unit ;
692
686
values [1 ] = - disp_unit ;
687
+ values [2 ] = size ;
688
+ values [3 ] = - (int64_t )size ;
693
689
694
- ret = module -> comm -> c_coll -> coll_allreduce (MPI_IN_PLACE , values , 2 , MPI_LONG ,
690
+ ret = module -> comm -> c_coll -> coll_allreduce (MPI_IN_PLACE , values , 4 , MPI_LONG ,
695
691
MPI_MIN , module -> comm ,
696
692
module -> comm -> c_coll -> coll_allreduce_module );
697
693
if (OMPI_SUCCESS != ret ) {
698
694
goto error ;
699
695
}
700
696
701
- if (values [0 ] == - values [1 ]) { /* everyone has the same disp_unit, we do not need O(p) space */
697
+ bool same_disp_unit = (values [0 ] == - values [1 ]);
698
+ bool same_size = (values [2 ] == - values [3 ]);
699
+
700
+ if (same_disp_unit ) { /* everyone has the same disp_unit, we do not need O(p) space */
702
701
module -> disp_unit = disp_unit ;
703
- } else { /* different disp_unit sizes, allocate O(p) space to store them */
704
- module -> disp_unit = -1 ;
705
- module -> disp_units = calloc (comm_size , sizeof (ptrdiff_t ));
706
- if (module -> disp_units == NULL ) {
707
- ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE ;
708
- goto error ;
709
- }
702
+ module -> disp_units = NULL ;
703
+ values [val_count ++ ] = disp_unit ;
704
+ }
705
+
706
+ if (same_size ) {
707
+ module -> same_size = true;
708
+ module -> sizes = NULL ;
709
+ values [val_count ++ ] = size ;
710
+ }
711
+
712
+ if (!same_disp_unit || !same_size ) {
710
713
711
- ret = module -> comm -> c_coll -> coll_allgather (& disp_unit , sizeof (ptrdiff_t ), MPI_BYTE ,
712
- module -> disp_units , sizeof (ptrdiff_t ) , MPI_BYTE ,
713
- module -> comm ,
714
- module -> comm -> c_coll -> coll_allgather_module );
714
+ ret = module -> comm -> c_coll -> coll_allgather (values , val_count * sizeof (long ), MPI_BYTE ,
715
+ ( void * ) my_info , sizeof (long ) * val_count , MPI_BYTE ,
716
+ module -> comm ,
717
+ module -> comm -> c_coll -> coll_allgather_module );
715
718
if (OMPI_SUCCESS != ret ) {
716
719
goto error ;
717
720
}
721
+
722
+ if (!same_disp_unit ) { /* everyone has a different disp_unit */
723
+ module -> disp_unit = -1 ;
724
+ module -> disp_units = calloc (comm_size , sizeof (ptrdiff_t ));
725
+ if (module -> disp_units == NULL ) {
726
+ ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE ;
727
+ goto error ;
728
+ }
729
+ for (i = 0 ; i < comm_size ; i ++ ) {
730
+ module -> disp_units [i ] = (ptrdiff_t )values [i * val_count ];
731
+ }
732
+ }
733
+
734
+ if (!same_size ) { /* everyone has the same disp_unit, we do not need O(p) space */
735
+ module -> same_size = false;
736
+ module -> sizes = calloc (comm_size , sizeof (size_t ));
737
+ if (module -> sizes == NULL ) {
738
+ ret = OMPI_ERR_TEMP_OUT_OF_RESOURCE ;
739
+ goto error ;
740
+ }
741
+
742
+ for (i = 0 ; i < comm_size ; i ++ ) {
743
+ module -> sizes [i ] = (size_t )values [i * val_count + val_count - 1 ];
744
+ }
745
+ }
746
+
718
747
}
719
748
720
749
ret = opal_common_ucx_wpctx_create (mca_osc_ucx_component .wpool , comm_size ,
@@ -1261,6 +1290,9 @@ int ompi_osc_ucx_free(struct ompi_win_t *win) {
1261
1290
if (module -> disp_units ) {
1262
1291
free (module -> disp_units );
1263
1292
}
1293
+ if (module -> sizes ) {
1294
+ free (module -> sizes );
1295
+ }
1264
1296
ompi_comm_free (& module -> comm );
1265
1297
1266
1298
free (module );
0 commit comments