Skip to content

Commit 1861ca1

Browse files
committed
pr feedback
Signed-off-by: Howard Pritchard <[email protected]>
1 parent ee7b084 commit 1861ca1

File tree

1 file changed

+37
-7
lines changed

1 file changed

+37
-7
lines changed

ompi/mca/coll/base/coll_base_bcast.c

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@
3535
#include "coll_base_topo.h"
3636
#include "coll_base_util.h"
3737

38+
/*
39+
* if a > b return a- b otherwise 0
40+
*/
41+
static inline size_t
42+
rectify_diff(size_t a, size_t b)
43+
{
44+
return a > b ? a - b : 0;
45+
}
46+
3847
int
3948
ompi_coll_base_bcast_intra_generic( void* buffer,
4049
size_t original_count,
@@ -812,8 +821,11 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
812821
if (vrank & mask) {
813822
int parent = (rank - mask + comm_size) % comm_size;
814823
/* Compute an upper bound on recv block size */
824+
recv_count = rectify_diff(count, (size_t)(vrank * scatter_count));
825+
#if 0
815826
recv_count = (count > vrank * scatter_count) ? (count - vrank * scatter_count) : 0;
816-
if (recv_count <= 0) {
827+
#endif
828+
if (recv_count == 0) {
817829
curr_count = 0;
818830
} else {
819831
/* Recv data from parent */
@@ -833,7 +845,10 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
833845
mask >>= 1;
834846
while (mask > 0) {
835847
if (vrank + mask < comm_size) {
848+
send_count = rectify_diff(curr_count, (size_t)(scatter_count * mask));
849+
#if 0
836850
send_count = (curr_count > scatter_count * mask) ? curr_count - scatter_count * mask : 0;
851+
#endif
837852
if (send_count > 0) {
838853
int child = (rank + mask) % comm_size;
839854
err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)scatter_count * (vrank + mask) * extent,
@@ -865,8 +880,11 @@ int ompi_coll_base_bcast_intra_scatter_allgather(
865880
if (vremote < comm_size) {
866881
ptrdiff_t send_offset = vrank_tree_root * scatter_count * extent;
867882
ptrdiff_t recv_offset = vremote_tree_root * scatter_count * extent;
883+
recv_count = rectify_diff(count, (size_t)(vremote_tree_root * scatter_count));
884+
#if 0
868885
recv_count = (count > vremote_tree_root * scatter_count) ?
869886
(count - vremote_tree_root * scatter_count) : 0;
887+
#endif
870888
err = ompi_coll_base_sendrecv((char *)buf + send_offset,
871889
curr_count, datatype, remote,
872890
MCA_COLL_BASE_TAG_BCAST,
@@ -986,7 +1004,10 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
9861004
if (vrank & mask) {
9871005
int parent = (rank - mask + comm_size) % comm_size;
9881006
/* Compute an upper bound on recv block size */
1007+
recv_count = rectify_diff(count, (size_t)(vrank * scatter_count));
1008+
#if 0
9891009
recv_count = (count > vrank * scatter_count) ? (count - vrank * scatter_count) : 0;
1010+
#endif
9901011
if (0 == recv_count) {
9911012
curr_count = 0;
9921013
} else {
@@ -1007,7 +1028,10 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
10071028
mask >>= 1;
10081029
while (mask > 0) {
10091030
if (vrank + mask < comm_size) {
1031+
send_count = rectify_diff(curr_count, (size_t)(scatter_count * mask));
1032+
#if 0
10101033
send_count = (curr_count > scatter_count * mask) ? (curr_count - scatter_count * mask) : 0;
1034+
#endif
10111035
if (send_count > 0) {
10121036
int child = (rank + mask) % comm_size;
10131037
err = MCA_PML_CALL(send((char *)buf + (ptrdiff_t)scatter_count * (vrank + mask) * extent,
@@ -1032,16 +1056,22 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
10321056
for (int i = 1; i < comm_size; ++i) {
10331057
/* how many elements remain in recv_block? */
10341058
size_t recv_offset_elems = recv_block * scatter_count;
1059+
size_t recv_remaining = rectify_diff(count, recv_offset_elems);
1060+
#if 0
10351061
size_t recv_remaining = (recv_offset_elems < count) ?
10361062
(count - recv_offset_elems) : 0;
1063+
#endif
10371064
recv_count = (recv_remaining < scatter_count) ?
10381065
recv_remaining : scatter_count;
10391066
size_t recv_offset = recv_offset_elems * extent;
10401067

10411068
/* same logic for send */
10421069
size_t send_offset_elems = send_block * scatter_count;
1070+
size_t send_remaining = rectify_diff(count, send_offset_elems);
1071+
#if 0
10431072
size_t send_remaining = (send_offset_elems < count) ?
10441073
(count - send_offset_elems) : 0;
1074+
#endif
10451075
send_count = (send_remaining < scatter_count) ?
10461076
send_remaining : scatter_count;
10471077
size_t send_offset = send_offset_elems * extent;
@@ -1051,13 +1081,13 @@ int ompi_coll_base_bcast_intra_scatter_allgather_ring(
10511081
(char*)buf + recv_offset, recv_count,
10521082
datatype, left, MCA_COLL_BASE_TAG_BCAST,
10531083
comm, MPI_STATUS_IGNORE, rank);
1054-
if (MPI_SUCCESS != err) {
1055-
goto cleanup_and_return;
1056-
}
1084+
if (MPI_SUCCESS != err) {
1085+
goto cleanup_and_return;
1086+
}
10571087

1058-
/* rotate blocks */
1059-
send_block = recv_block;
1060-
recv_block = (recv_block + comm_size - 1) % comm_size;
1088+
/* rotate blocks */
1089+
send_block = recv_block;
1090+
recv_block = (recv_block + comm_size - 1) % comm_size;
10611091
}
10621092

10631093
cleanup_and_return:

0 commit comments

Comments
 (0)