@@ -1268,10 +1268,9 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf
12681268 mca_coll_base_module_t * module )
12691269{
12701270 char * send_buf = (void * ) sbuf ;
1271- int comm_size = ompi_comm_size (comm );
1271+ const int comm_size = ompi_comm_size (comm );
1272+ const int rank = ompi_comm_rank (comm );
12721273 int err = MPI_SUCCESS ;
1273- int rank = ompi_comm_rank (comm );
1274- bool commutative = ompi_op_is_commute (op );
12751274 ompi_request_t * * reqs ;
12761275
12771276 if (sbuf == MPI_IN_PLACE ) {
@@ -1288,24 +1287,30 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf
12881287 return OMPI_ERR_OUT_OF_RESOURCE ;
12891288 }
12901289
1291- if (commutative ) {
1292- ompi_datatype_copy_content_same_ddt (dtype , count , (char * ) rbuf , (char * ) send_buf );
1293- }
1294-
12951290 tmp_buf = tmp_buf_raw - gap ;
12961291
12971292 /* Requests for send to AND receive from everyone else */
12981293 int reqs_needed = (comm_size - 1 ) * 2 ;
12991294 reqs = ompi_coll_base_comm_get_reqs (module -> base_data , reqs_needed );
13001295
1301- ptrdiff_t incr = extent * count ;
1302- tmp_recv = (char * ) tmp_buf ;
1296+ const ptrdiff_t incr = extent * count ;
13031297
1304- /* Exchange data with peer processes */
1298+ /* Exchange data with peer processes, excluding self */
13051299 int req_index = 0 , peer_rank = 0 ;
13061300 for (int i = 1 ; i < comm_size ; ++ i ) {
1301+ /* Start at the next rank */
13071302 peer_rank = (rank + i ) % comm_size ;
1308- tmp_recv = tmp_buf + (peer_rank * incr );
1303+
1304+ /* Prepare for the next receive buffer */
1305+ if (0 == peer_rank && rbuf != send_buf ) {
1306+ /* Optimization for Rank 0 - its data will always be placed at the beginning of local
1307+ * reduce output buffer.
1308+ */
1309+ tmp_recv = rbuf ;
1310+ } else {
1311+ tmp_recv = tmp_buf + (peer_rank * incr );
1312+ }
1313+
13091314 err = MCA_PML_CALL (irecv (tmp_recv , count , dtype , peer_rank , MCA_COLL_BASE_TAG_ALLREDUCE ,
13101315 comm , & reqs [req_index ++ ]));
13111316 if (MPI_SUCCESS != err ) {
@@ -1321,17 +1326,29 @@ int ompi_coll_base_allreduce_intra_allgather_reduce(const void *sbuf, void *rbuf
13211326
13221327 err = ompi_request_wait_all (req_index , reqs , MPI_STATUSES_IGNORE );
13231328
1324- /* Prepare for local reduction */
1325- peer_rank = 0 ;
1326- if (!commutative ) {
1327- /* For non-commutative operations, ensure the reduction always starts from Rank 0's data */
1328- memcpy (rbuf , 0 == rank ? send_buf : tmp_buf , incr );
1329- peer_rank = 1 ;
1329+ /**
1330+ * Prepare for local reduction by moving Rank 0's data to rbuf.
1331+ * Previously we tried to receive Rank 0's data in rbuf, but we need to handle
1332+ * the following special cases.
1333+ */
1334+ if (0 != rank && rbuf == send_buf ) {
1335+ /* For inplace reduction copy out the send_buf before moving Rank 0's data */
1336+ ompi_datatype_copy_content_same_ddt (dtype , count , (char * ) tmp_buf + (rank * incr ),
1337+ send_buf );
1338+ ompi_datatype_copy_content_same_ddt (dtype , count , (char * ) rbuf , (char * ) tmp_buf );
1339+ } else if (0 == rank && rbuf != send_buf ) {
1340+ /* For Rank 0 we need to copy the send_buf to rbuf manually */
1341+ ompi_datatype_copy_content_same_ddt (dtype , count , (char * ) rbuf , (char * ) send_buf );
13301342 }
13311343
1332- char * inbuf ;
1333- for (; peer_rank < comm_size ; peer_rank ++ ) {
1334- inbuf = rank == peer_rank ? send_buf : tmp_buf + (peer_rank * incr );
1344+ /* Now do local reduction - Rank 0's data is already in rbuf so start from Rank 1 */
1345+ char * inbuf = NULL ;
1346+ for (peer_rank = 1 ; peer_rank < comm_size ; peer_rank ++ ) {
1347+ if (rank == peer_rank && rbuf != send_buf ) {
1348+ inbuf = send_buf ;
1349+ } else {
1350+ inbuf = tmp_buf + (peer_rank * incr );
1351+ }
13351352 ompi_op_reduce (op , (void * ) inbuf , rbuf , count , dtype );
13361353 }
13371354
0 commit comments