3535#include "coll_base_topo.h"
3636#include "coll_base_util.h"
3737
/*
 * Saturating ("rectified") difference of two unsigned sizes.
 *
 * Returns a - b when a is strictly greater than b, and 0 otherwise.
 * Because size_t is unsigned, a plain "a - b" with b > a would wrap
 * around to a huge value instead of going negative; this helper clamps
 * that case to 0 so the result can be used directly as an element count.
 *
 * a: value to subtract from
 * b: value to subtract
 */
static inline size_t
rectify_diff(size_t a, size_t b)
{
    return a > b ? a - b : 0;
}
46+
3847int
3948ompi_coll_base_bcast_intra_generic ( void * buffer ,
4049 size_t original_count ,
@@ -812,8 +821,11 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
812821 if (vrank & mask ) {
813822 int parent = (rank - mask + comm_size ) % comm_size ;
814823 /* Compute an upper bound on recv block size */
824+ recv_count = rectify_diff (count , (size_t )(vrank * scatter_count ));
825+ #if 0
815826 recv_count = (count > vrank * scatter_count ) ? (count - vrank * scatter_count ) : 0 ;
816- if (recv_count <= 0 ) {
827+ #endif
828+ if (recv_count == 0 ) {
817829 curr_count = 0 ;
818830 } else {
819831 /* Recv data from parent */
@@ -833,7 +845,10 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
833845 mask >>= 1 ;
834846 while (mask > 0 ) {
835847 if (vrank + mask < comm_size ) {
848+ send_count = rectify_diff (curr_count , (size_t )(scatter_count * mask ));
849+ #if 0
836850 send_count = (curr_count > scatter_count * mask ) ? curr_count - scatter_count * mask : 0 ;
851+ #endif
837852 if (send_count > 0 ) {
838853 int child = (rank + mask ) % comm_size ;
839854 err = MCA_PML_CALL (send ((char * )buf + (ptrdiff_t )scatter_count * (vrank + mask ) * extent ,
@@ -865,8 +880,11 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
865880 if (vremote < comm_size ) {
866881 ptrdiff_t send_offset = vrank_tree_root * scatter_count * extent ;
867882 ptrdiff_t recv_offset = vremote_tree_root * scatter_count * extent ;
883+ recv_count = rectify_diff (count , (size_t )(vremote_tree_root * scatter_count ));
884+ #if 0
868885 recv_count = (count > vremote_tree_root * scatter_count ) ?
869886 (count - vremote_tree_root * scatter_count ) : 0 ;
887+ #endif
870888 err = ompi_coll_base_sendrecv ((char * )buf + send_offset ,
871889 curr_count , datatype , remote ,
872890 MCA_COLL_BASE_TAG_BCAST ,
@@ -986,7 +1004,10 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
9861004 if (vrank & mask ) {
9871005 int parent = (rank - mask + comm_size ) % comm_size ;
9881006 /* Compute an upper bound on recv block size */
1007+ recv_count = rectify_diff (count , (size_t )(vrank * scatter_count ));
1008+ #if 0
9891009 recv_count = (count > vrank * scatter_count ) ? (count - vrank * scatter_count ) : 0 ;
1010+ #endif
9901011 if (0 == recv_count ) {
9911012 curr_count = 0 ;
9921013 } else {
@@ -1007,7 +1028,10 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
10071028 mask >>= 1 ;
10081029 while (mask > 0 ) {
10091030 if (vrank + mask < comm_size ) {
1031+ send_count = rectify_diff (curr_count , (size_t )(scatter_count * mask ));
1032+ #if 0
10101033 send_count = (curr_count > scatter_count * mask ) ? (curr_count - scatter_count * mask ) : 0 ;
1034+ #endif
10111035 if (send_count > 0 ) {
10121036 int child = (rank + mask ) % comm_size ;
10131037 err = MCA_PML_CALL (send ((char * )buf + (ptrdiff_t )scatter_count * (vrank + mask ) * extent ,
@@ -1032,16 +1056,22 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
10321056 for (int i = 1 ; i < comm_size ; ++ i ) {
10331057 /* how many elements remain in recv_block? */
10341058 size_t recv_offset_elems = recv_block * scatter_count ;
1059+ size_t recv_remaining = rectify_diff (count , recv_offset_elems );
1060+ #if 0
10351061 size_t recv_remaining = (recv_offset_elems < count ) ?
10361062 (count - recv_offset_elems ) : 0 ;
1063+ #endif
10371064 recv_count = (recv_remaining < scatter_count ) ?
10381065 recv_remaining : scatter_count ;
10391066 size_t recv_offset = recv_offset_elems * extent ;
10401067
10411068 /* same logic for send */
10421069 size_t send_offset_elems = send_block * scatter_count ;
1070+ size_t send_remaining = rectify_diff (count , send_offset_elems );
1071+ #if 0
10431072 size_t send_remaining = (send_offset_elems < count ) ?
10441073 (count - send_offset_elems ) : 0 ;
1074+ #endif
10451075 send_count = (send_remaining < scatter_count ) ?
10461076 send_remaining : scatter_count ;
10471077 size_t send_offset = send_offset_elems * extent ;
@@ -1051,13 +1081,13 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
10511081 (char * )buf + recv_offset , recv_count ,
10521082 datatype , left , MCA_COLL_BASE_TAG_BCAST ,
10531083 comm , MPI_STATUS_IGNORE , rank );
1054- if (MPI_SUCCESS != err ) {
1055- goto cleanup_and_return ;
1056- }
1084+ if (MPI_SUCCESS != err ) {
1085+ goto cleanup_and_return ;
1086+ }
10571087
1058- /* rotate blocks */
1059- send_block = recv_block ;
1060- recv_block = (recv_block + comm_size - 1 ) % comm_size ;
1088+ /* rotate blocks */
1089+ send_block = recv_block ;
1090+ recv_block = (recv_block + comm_size - 1 ) % comm_size ;
10611091 }
10621092
10631093cleanup_and_return :
0 commit comments