  * and Technology (RIST). All rights reserved.
  * Copyright (c) 2018 Siberian State University of Telecommunications
  * and Information Sciences. All rights reserved.
+ * Copyright (c) 2022 IBM Corporation. All rights reserved.
  * $COPYRIGHT$
  *
  * Additional copyrights may follow
@@ -58,7 +59,8 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
                                                  struct ompi_communicator_t *comm,
                                                  mca_coll_base_module_t *module)
 {
-    int rank, size, count, err = OMPI_SUCCESS;
+    int rank, size, err = OMPI_SUCCESS;
+    size_t count;
     ptrdiff_t gap, span;
     char *recv_buf = NULL, *recv_buf_free = NULL;
 
@@ -67,40 +69,106 @@ ompi_coll_base_reduce_scatter_block_basic_linear(const void *sbuf, void *rbuf, i
     size = ompi_comm_size(comm);
 
     /* short cut the trivial case */
-    count = rcount * size;
+    count = rcount * (size_t)size;
     if (0 == count) {
         return OMPI_SUCCESS;
     }
 
-    /* get datatype information */
-    span = opal_datatype_span(&dtype->super, count, &gap);
-
     /* Handle MPI_IN_PLACE */
     if (MPI_IN_PLACE == sbuf) {
         sbuf = rbuf;
     }
 
-    if (0 == rank) {
-        /* temporary receive buffer.  See coll_basic_reduce.c for
-           details on sizing */
-        recv_buf_free = (char*) malloc(span);
-        if (NULL == recv_buf_free) {
-            err = OMPI_ERR_OUT_OF_RESOURCE;
-            goto cleanup;
+    /*
+     * For large payload (defined as a count greater than INT_MAX)
+     * to reduce the memory footprint on the root we segment the
+     * reductions per rank, then send to each rank.
+     *
+     * Additionally, sending the message in the coll_reduce() as
+     * "rcount*size" would exceed the 'int count' parameter in the
+     * coll_reduce() function. So another technique is required
+     * for count values that exceed INT_MAX.
+     */
+    if( OPAL_UNLIKELY(count > INT_MAX) ) {
+        int i;
+        void *sbuf_ptr;
+
+        /* Get datatype information for an individual block */
+        span = opal_datatype_span(&dtype->super, rcount, &gap);
+
+        if (0 == rank) {
+            /* temporary receive buffer.  See coll_basic_reduce.c for
+               details on sizing */
+            recv_buf_free = (char*) malloc(span);
+            if (NULL == recv_buf_free) {
+                err = OMPI_ERR_OUT_OF_RESOURCE;
+                goto cleanup;
+            }
+            recv_buf = recv_buf_free - gap;
+        }
+
+        for( i = 0; i < size; ++i ) {
+            /* Calculate the portion of the send buffer to reduce over */
+            sbuf_ptr = (char*)sbuf + span * (size_t)i;
+
+            /* Reduction for this peer */
+            err = comm->c_coll->coll_reduce(sbuf_ptr, recv_buf, rcount,
+                                            dtype, op, 0, comm,
+                                            comm->c_coll->coll_reduce_module);
+            if (MPI_SUCCESS != err) {
+                goto cleanup;
+            }
+
+            /* Send reduce results to this peer */
+            if (0 == rank) {
+                if( i == rank ) {
+                    err = ompi_datatype_copy_content_same_ddt(dtype, rcount, rbuf, recv_buf);
+                } else {
+                    err = MCA_PML_CALL(send(recv_buf, rcount, dtype, i,
+                                            MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
+                                            MCA_PML_BASE_SEND_STANDARD, comm));
+                }
+                if (MPI_SUCCESS != err) {
+                    goto cleanup;
+                }
+            }
+            else if( i == rank ) {
+                err = MCA_PML_CALL(recv(rbuf, rcount, dtype, 0,
+                                        MCA_COLL_BASE_TAG_REDUCE_SCATTER_BLOCK,
+                                        comm, MPI_STATUS_IGNORE));
+                if (MPI_SUCCESS != err) {
+                    goto cleanup;
+                }
+            }
         }
-        recv_buf = recv_buf_free - gap;
     }
+    else {
+        /* get datatype information */
+        span = opal_datatype_span(&dtype->super, count, &gap);
+
+        if (0 == rank) {
+            /* temporary receive buffer.  See coll_basic_reduce.c for
+               details on sizing */
+            recv_buf_free = (char*) malloc(span);
+            if (NULL == recv_buf_free) {
+                err = OMPI_ERR_OUT_OF_RESOURCE;
+                goto cleanup;
+            }
+            recv_buf = recv_buf_free - gap;
+        }
 
-    /* reduction */
-    err =
-        comm->c_coll->coll_reduce(sbuf, recv_buf, count, dtype, op, 0,
-                                  comm, comm->c_coll->coll_reduce_module);
+        /* reduction */
+        err =
+            comm->c_coll->coll_reduce(sbuf, recv_buf, (int)count, dtype, op, 0,
+                                      comm, comm->c_coll->coll_reduce_module);
+        if (MPI_SUCCESS != err) {
+            goto cleanup;
+        }
 
-    /* scatter */
-    if (MPI_SUCCESS == err) {
+        /* scatter */
         err = comm->c_coll->coll_scatter(recv_buf, rcount, dtype,
-                                         rbuf, rcount, dtype, 0,
-                                         comm, comm->c_coll->coll_scatter_module);
+                                         rbuf, rcount, dtype, 0,
+                                         comm, comm->c_coll->coll_scatter_module);
     }
 
 cleanup:
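To see why the new branch on count > INT_MAX is needed, here is a small standalone sketch (not part of the patch; the rank count and block size are illustrative, and a 64-bit size_t is assumed). It shows the total element count exceeding what the old int arithmetic could hold, and the guard that selects the segmented per-rank path:

#include <limits.h>
#include <stdio.h>

int main(void)
{
    /* Illustrative values only: 4096 ranks, 2^20 elements per block. */
    int size = 4096;
    int rcount = 1 << 20;

    /* Promote to size_t before multiplying, as the patch now does;
       the old "rcount * size" was evaluated in int and overflowed. */
    size_t count = (size_t)rcount * (size_t)size;

    if (count > INT_MAX) {
        /* 4294967296 elements cannot be passed through the 'int count'
           parameter of coll_reduce(), so the patched code reduces one
           rcount-sized block per peer and sends/copies each block to
           its owner instead of doing a single reduce + scatter. */
        printf("count = %zu > INT_MAX: segmented per-rank path\n", count);
    } else {
        printf("count = %zu fits in an int: single reduce + scatter\n", count);
    }
    return 0;
}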
@@ -146,7 +214,16 @@ ompi_coll_base_reduce_scatter_block_intra_recursivedoubling(
     if (comm_size < 2)
         return MPI_SUCCESS;
 
-    totalcount = comm_size * rcount;
+    totalcount = comm_size * (size_t)rcount;
+    if( OPAL_UNLIKELY(totalcount > INT_MAX) ) {
+        /*
+         * Large payload collectives are not supported by this algorithm.
+         * The blocklens and displs calculations in the loop below
+         * will overflow an int data type.
+         * Fallback to the linear algorithm.
+         */
+        return ompi_coll_base_reduce_scatter_block_basic_linear(sbuf, rbuf, rcount, dtype, op, comm, module);
+    }
     ompi_datatype_type_extent(dtype, &extent);
     span = opal_datatype_span(&dtype->super, totalcount, &gap);
     tmpbuf_raw = malloc(span);
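The comment in the hunk above says the int-typed block lengths and displacements built later in this algorithm would overflow. A quick check with illustrative numbers (not taken from the patch) makes that concrete:

#include <limits.h>
#include <stdio.h>

int main(void)
{
    /* Same illustrative large-payload case: 4096 ranks, 2^20 elements each. */
    int comm_size = 4096;
    int rcount = 1 << 20;

    /* Displacement, in elements, of the last rank's block, computed in a
       64-bit type so the arithmetic itself cannot overflow here. */
    long long last_displ = (long long)(comm_size - 1) * rcount;

    printf("last displacement = %lld, INT_MAX = %d\n", last_displ, INT_MAX);
    /* 4293918720 > 2147483647, so storing such a displacement in an int,
       as the recursive-doubling loop does, would overflow; hence the
       fallback to the linear algorithm for totals above INT_MAX. */
    return 0;
}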
@@ -347,7 +424,8 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
         return ompi_coll_base_reduce_scatter_block_basic_linear(sbuf, rbuf, rcount, dtype,
                                                                 op, comm, module);
     }
-    totalcount = comm_size * rcount;
+
+    totalcount = comm_size * (size_t)rcount;
     ompi_datatype_type_extent(dtype, &extent);
     span = opal_datatype_span(&dtype->super, totalcount, &gap);
     tmpbuf_raw = malloc(span);
@@ -431,22 +509,22 @@ ompi_coll_base_reduce_scatter_block_intra_recursivehalving(
              * have their result calculated by the process to their
              * right (rank + 1).
              */
-            int send_count = 0, recv_count = 0;
+            size_t send_count = 0, recv_count = 0;
             if (vrank < vpeer) {
                 /* Send the right half of the buffer, recv the left half */
                 send_index = recv_index + mask;
-                send_count = rcount * ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
-                recv_count = rcount * ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
+                send_count = rcount * (size_t)ompi_range_sum(send_index, last_index - 1, nprocs_rem - 1);
+                recv_count = rcount * (size_t)ompi_range_sum(recv_index, send_index - 1, nprocs_rem - 1);
             } else {
                 /* Send the left half of the buffer, recv the right half */
                 recv_index = send_index + mask;
-                send_count = rcount * ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
-                recv_count = rcount * ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
+                send_count = rcount * (size_t)ompi_range_sum(send_index, recv_index - 1, nprocs_rem - 1);
+                recv_count = rcount * (size_t)ompi_range_sum(recv_index, last_index - 1, nprocs_rem - 1);
             }
-            ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ?
-                                         2 * recv_index : nprocs_rem + recv_index);
-            ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
-                                         2 * send_index : nprocs_rem + send_index);
+            ptrdiff_t rdispl = rcount * (size_t)((recv_index <= nprocs_rem - 1) ?
+                                                 2 * recv_index : nprocs_rem + recv_index);
+            ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
+                                                 2 * send_index : nprocs_rem + send_index);
             struct ompi_request_t *request = NULL;
 
             if (recv_count > 0) {
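The (size_t) casts above address a subtle point: even though send_count, recv_count, rdispl, and sdispl are now stored in wider types, C still evaluates int * int in int, so the widening has to happen before the multiplication. A minimal sketch with made-up values, assuming a 64-bit size_t/ptrdiff_t:

#include <stddef.h>
#include <stdio.h>

int main(void)
{
    int rcount = 1 << 20;   /* elements per block (illustrative) */
    int nblocks = 3000;     /* e.g. a range-sum result (made up)  */

    /* Wrong: the product is computed in int and overflows before it is
       widened, even though the destination is 64-bit:
       ptrdiff_t bad = rcount * nblocks;                                */

    /* Right: widen one operand first, as the patch does. */
    ptrdiff_t good = (size_t)rcount * nblocks;

    printf("good = %td\n", good);   /* 3145728000, which exceeds INT_MAX */
    return 0;
}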
@@ -587,7 +665,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
             sbuf, rbuf, rcount, dtype, op, comm, module);
     }
 
-    totalcount = comm_size * rcount;
+    totalcount = comm_size * (size_t)rcount;
     ompi_datatype_type_extent(dtype, &extent);
     span = opal_datatype_span(&dtype->super, totalcount, &gap);
     tmpbuf[0] = malloc(span);
@@ -677,13 +755,17 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
             /* Send the upper half of reduction buffer, recv the lower half */
             recv_index += nblocks;
         }
-        int send_count = rcount * ompi_range_sum(send_index,
-                                                 send_index + nblocks - 1, nprocs_rem - 1);
-        int recv_count = rcount * ompi_range_sum(recv_index,
-                                                 recv_index + nblocks - 1, nprocs_rem - 1);
-        ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
+        size_t send_count = rcount *
+            (size_t)ompi_range_sum(send_index,
+                                   send_index + nblocks - 1,
+                                   nprocs_rem - 1);
+        size_t recv_count = rcount *
+            (size_t)ompi_range_sum(recv_index,
+                                   recv_index + nblocks - 1,
+                                   nprocs_rem - 1);
+        ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
                                     2 * send_index : nprocs_rem + send_index);
-        ptrdiff_t rdispl = rcount * ((recv_index <= nprocs_rem - 1) ?
+        ptrdiff_t rdispl = rcount * (size_t)((recv_index <= nprocs_rem - 1) ?
                                     2 * recv_index : nprocs_rem + recv_index);
 
         err = ompi_coll_base_sendrecv(psend + (ptrdiff_t)sdispl * extent, send_count,
@@ -719,7 +801,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
              * Process has two blocks: for excluded process and own.
              * Send result to the excluded process.
              */
-            ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
+            ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
                                         2 * send_index : nprocs_rem + send_index);
             err = MCA_PML_CALL(send(psend + (ptrdiff_t)sdispl * extent,
                                     rcount, dtype, peer - 1,
@@ -729,7 +811,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly(
         }
 
         /* Send result to a remote process according to a mirror permutation */
-        ptrdiff_t sdispl = rcount * ((send_index <= nprocs_rem - 1) ?
+        ptrdiff_t sdispl = rcount * (size_t)((send_index <= nprocs_rem - 1) ?
                                     2 * send_index : nprocs_rem + send_index);
         /* If process has two blocks, then send the second block (own block) */
         if (vpeer < nprocs_rem)
@@ -821,7 +903,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2(
     if (rcount == 0 || comm_size < 2)
         return MPI_SUCCESS;
 
-    totalcount = comm_size * rcount;
+    totalcount = comm_size * (size_t)rcount;
     ompi_datatype_type_extent(dtype, &extent);
     span = opal_datatype_span(&dtype->super, totalcount, &gap);
     tmpbuf[0] = malloc(span);
@@ -843,7 +925,7 @@ ompi_coll_base_reduce_scatter_block_intra_butterfly_pof2(
         if (MPI_SUCCESS != err) { goto cleanup_and_return; }
     }
 
-    int nblocks = totalcount, send_index = 0, recv_index = 0;
+    size_t nblocks = totalcount, send_index = 0, recv_index = 0;
     for (int mask = 1; mask < comm_size; mask <<= 1) {
         int peer = rank ^ mask;
         nblocks /= 2;
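For completeness, here is a small test program (not part of this commit) that exercises MPI_Reduce_scatter_block, the collective these base implementations serve. With a large enough rcount and rank count, comm_size * rcount exceeds INT_MAX and the basic linear algorithm takes the new segmented path; the tiny rcount used here only demonstrates the call sequence and the expected result.

#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>

int main(int argc, char **argv)
{
    MPI_Init(&argc, &argv);

    int rank, size;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    const int rcount = 4;                        /* block size per rank */
    int *sbuf = malloc(sizeof(int) * rcount * size);
    int *rbuf = malloc(sizeof(int) * rcount);

    /* Every rank contributes the value 1 for each element of every block. */
    for (int i = 0; i < rcount * size; i++) sbuf[i] = 1;

    /* Each rank receives the reduction of its own rcount-sized block. */
    MPI_Reduce_scatter_block(sbuf, rbuf, rcount, MPI_INT, MPI_SUM,
                             MPI_COMM_WORLD);

    /* Every element of rbuf should equal the communicator size. */
    if (rbuf[0] != size) {
        fprintf(stderr, "rank %d: unexpected result %d\n", rank, rbuf[0]);
    }

    free(sbuf);
    free(rbuf);
    MPI_Finalize();
    return 0;
}

Run it with, for example, mpirun -np 4 ./reduce_scatter_block_test. Forcing a particular algorithm (for instance via the coll_tuned_reduce_scatter_block_algorithm MCA parameter, if present in your build) makes it easier to hit a specific code path from the hunks above.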