1414 * Copyright (c) 2016 Research Organization for Information Science
1515 * and Technology (RIST). All rights reserved.
1616 * Copyright (c) 2017 IBM Corporation. All rights reserved.
17+ * Copyright (c) 2025 Triad National Security, LLC. All rights reserved.
1718 * $COPYRIGHT$
1819 *
1920 * Additional copyrights may follow
@@ -811,7 +812,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
811812 if (vrank & mask ) {
812813 int parent = (rank - mask + comm_size ) % comm_size ;
813814 /* Compute an upper bound on recv block size */
814- recv_count = count - vrank * scatter_count ;
815+ recv_count = ( count > vrank * scatter_count ) ? ( count - vrank * scatter_count ) : 0 ;
815816 if (recv_count <= 0 ) {
816817 curr_count = 0 ;
817818 } else {
@@ -832,7 +833,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
832833 mask >>= 1 ;
833834 while (mask > 0 ) {
834835 if (vrank + mask < comm_size ) {
835- send_count = curr_count - scatter_count * mask ;
836+ send_count = ( curr_count > scatter_count * mask ) ? curr_count - scatter_count * mask : 0 ;
836837 if (send_count > 0 ) {
837838 int child = (rank + mask ) % comm_size ;
838839 err = MCA_PML_CALL (send ((char * )buf + (ptrdiff_t )scatter_count * (vrank + mask ) * extent ,
@@ -850,10 +851,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
850851 * Allgather by recursive doubling
851852 * Each process has the curr_count elems in the buf[vrank * scatter_count, ...]
852853 */
853- size_t rem_count = count - vrank * scatter_count ;
854+ size_t rem_count = ( count > vrank * scatter_count ) ? count - vrank * scatter_count : 0 ;
854855 curr_count = (scatter_count < rem_count ) ? scatter_count : rem_count ;
855- if (curr_count < 0 )
856- curr_count = 0 ;
857856
858857 mask = 0x1 ;
859858 while (mask < comm_size ) {
@@ -866,9 +865,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
866865 if (vremote < comm_size ) {
867866 ptrdiff_t send_offset = vrank_tree_root * scatter_count * extent ;
868867 ptrdiff_t recv_offset = vremote_tree_root * scatter_count * extent ;
869- recv_count = count - vremote_tree_root * scatter_count ;
870- if (recv_count < 0 )
871- recv_count = 0 ;
868+ recv_count = (count > vremote_tree_root * scatter_count ) ?
869+ (count - vremote_tree_root * scatter_count ) : 0 ;
872870 err = ompi_coll_base_sendrecv ((char * )buf + send_offset ,
873871 curr_count , datatype , remote ,
874872 MCA_COLL_BASE_TAG_BCAST ,
@@ -877,7 +875,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
877875 MCA_COLL_BASE_TAG_BCAST ,
878876 comm , & status , rank );
879877 if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
880- recv_count = (int )(status ._ucount / datatype_size );
878+ recv_count = (size_t )(status ._ucount / datatype_size );
881879 curr_count += recv_count ;
882880 }
883881
@@ -913,7 +911,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
913911 MCA_COLL_BASE_TAG_BCAST ,
914912 comm , & status ));
915913 if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
916- recv_count = (int )(status ._ucount / datatype_size );
914+ recv_count = (size_t )(status ._ucount / datatype_size );
917915 curr_count += recv_count ;
918916 }
919917 }
@@ -988,8 +986,8 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
988986 if (vrank & mask ) {
989987 int parent = (rank - mask + comm_size ) % comm_size ;
990988 /* Compute an upper bound on recv block size */
991- recv_count = count - vrank * scatter_count ;
992- if (recv_count <= 0 ) {
989+ recv_count = ( count > vrank * scatter_count ) ? ( count - vrank * scatter_count ) : 0 ;
990+ if (0 == recv_count ) {
993991 curr_count = 0 ;
994992 } else {
995993 /* Recv data from parent */
@@ -1009,7 +1007,7 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
10091007 mask >>= 1 ;
10101008 while (mask > 0 ) {
10111009 if (vrank + mask < comm_size ) {
1012- send_count = curr_count - scatter_count * mask ;
1010+ send_count = ( curr_count > scatter_count * mask ) ? ( curr_count - scatter_count * mask ) : 0 ;
10131011 if (send_count > 0 ) {
10141012 int child = (rank + mask ) % comm_size ;
10151013 err = MCA_PML_CALL (send ((char * )buf + (ptrdiff_t )scatter_count * (vrank + mask ) * extent ,
@@ -1023,33 +1021,43 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
10231021 mask >>= 1 ;
10241022 }
10251023
1026- /* Allgather by a ring algorithm */
1024+ /* Allgather by a ring algorithm, using only unsigned types */
10271025 int left = (rank - 1 + comm_size ) % comm_size ;
10281026 int right = (rank + 1 ) % comm_size ;
1027+
1028+ /* The block we will send/recv in each step */
10291029 int send_block = vrank ;
10301030 int recv_block = (vrank - 1 + comm_size ) % comm_size ;
10311031
1032- for (int i = 1 ; i < comm_size ; i ++ ) {
1033- recv_count = (scatter_count < count - recv_block * scatter_count ) ?
1034- scatter_count : count - recv_block * scatter_count ;
1035- if (recv_count < 0 )
1036- recv_count = 0 ;
1037- ptrdiff_t recv_offset = recv_block * scatter_count * extent ;
1038-
1039- send_count = (scatter_count < count - send_block * scatter_count ) ?
1040- scatter_count : count - send_block * scatter_count ;
1041- if (send_count < 0 )
1042- send_count = 0 ;
1043- ptrdiff_t send_offset = send_block * scatter_count * extent ;
1044-
1045- err = ompi_coll_base_sendrecv ((char * )buf + send_offset , send_count ,
1032+ for (int i = 1 ; i < comm_size ; ++ i ) {
1033+ /* how many elements remain in recv_block? */
1034+ size_t recv_offset_elems = recv_block * scatter_count ;
1035+ size_t recv_remaining = (recv_offset_elems < count ) ?
1036+ (count - recv_offset_elems ) : 0 ;
1037+ recv_count = (recv_remaining < scatter_count ) ?
1038+ recv_remaining : scatter_count ;
1039+ size_t recv_offset = recv_offset_elems * extent ;
1040+
1041+ /* same logic for send */
1042+ size_t send_offset_elems = send_block * scatter_count ;
1043+ size_t send_remaining = (send_offset_elems < count ) ?
1044+ (count - send_offset_elems ) : 0 ;
1045+ send_count = (send_remaining < scatter_count ) ?
1046+ send_remaining : scatter_count ;
1047+ size_t send_offset = send_offset_elems * extent ;
1048+
1049+ err = ompi_coll_base_sendrecv ((char * )buf + send_offset , send_count ,
10461050 datatype , right , MCA_COLL_BASE_TAG_BCAST ,
1047- (char * )buf + recv_offset , recv_count ,
1051+ (char * )buf + recv_offset , recv_count ,
10481052 datatype , left , MCA_COLL_BASE_TAG_BCAST ,
10491053 comm , MPI_STATUS_IGNORE , rank );
1050- if (MPI_SUCCESS != err ) { goto cleanup_and_return ; }
1051- send_block = recv_block ;
1052- recv_block = (recv_block - 1 + comm_size ) % comm_size ;
1054+ if (MPI_SUCCESS != err ) {
1055+ goto cleanup_and_return ;
1056+ }
1057+
1058+ /* rotate blocks */
1059+ send_block = recv_block ;
1060+ recv_block = (recv_block + comm_size - 1 ) % comm_size ;
10531061 }
10541062
10551063cleanup_and_return :
0 commit comments