1717 * Copyright (c) 2017-2022 IBM Corporation. All rights reserved.
1818 * Copyright (c) 2021 Amazon.com, Inc. or its affiliates. All Rights
1919 * reserved.
20+ * Copyright (c) 2022 BULL S.A.S. All rights reserved.
2021 * $COPYRIGHT$
2122 *
2223 * Additional copyrights may follow
@@ -222,8 +223,8 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
222223 struct ompi_communicator_t * comm ,
223224 mca_coll_base_module_t * module )
224225{
225- int i , k , line = -1 , rank , size , err = 0 ;
226- int sendto , recvfrom , distance , * displs = NULL , * blen = NULL ;
226+ int i , line = -1 , rank , size , err = 0 ;
227+ int sendto , recvfrom , distance , * displs = NULL ;
227228 char * tmpbuf = NULL , * tmpbuf_free = NULL ;
228229 ptrdiff_t sext , rext , span , gap = 0 ;
229230 struct ompi_datatype_t * new_ddt ;
@@ -245,31 +246,31 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
245246 err = ompi_datatype_type_extent (rdtype , & rext );
246247 if (err != MPI_SUCCESS ) { line = __LINE__ ; goto err_hndl ; }
247248
248- span = opal_datatype_span (& sdtype -> super , (int64_t )size * scount , & gap );
249-
250- displs = (int * ) malloc (size * sizeof (int ));
251- if (displs == NULL ) { line = __LINE__ ; err = -1 ; goto err_hndl ; }
252- blen = (int * ) malloc (size * sizeof (int ));
253- if (blen == NULL ) { line = __LINE__ ; err = -1 ; goto err_hndl ; }
249+ span = opal_datatype_span (& rdtype -> super , (int64_t )size * rcount , & gap );
254250
255251 /* tmp buffer allocation for message data */
256252 tmpbuf_free = (char * )malloc (span );
257253 if (tmpbuf_free == NULL ) { line = __LINE__ ; err = -1 ; goto err_hndl ; }
258254 tmpbuf = tmpbuf_free - gap ;
259255
260256 /* Step 1 - local rotation - shift up by rank */
261- err = ompi_datatype_copy_content_same_ddt (sdtype ,
262- (int32_t ) ((ptrdiff_t )(size - rank ) * (ptrdiff_t )scount ),
263- tmpbuf ,
264- ((char * ) sbuf ) + (ptrdiff_t )rank * (ptrdiff_t )scount * sext );
257+ err = ompi_datatype_sndrcv (sbuf + ((ptrdiff_t ) rank * scount * sext ),
258+ (int32_t ) (size - rank ) * scount ,
259+ sdtype ,
260+ tmpbuf ,
261+ (int32_t ) (size - rank ) * rcount ,
262+ rdtype );
265263 if (err < 0 ) {
266264 line = __LINE__ ; err = -1 ; goto err_hndl ;
267265 }
268266
269267 if (rank != 0 ) {
270- err = ompi_datatype_copy_content_same_ddt (sdtype , (ptrdiff_t )rank * (ptrdiff_t )scount ,
271- tmpbuf + (ptrdiff_t )(size - rank ) * (ptrdiff_t )scount * sext ,
272- (char * ) sbuf );
268+ err = ompi_datatype_sndrcv (sbuf ,
269+ (int32_t ) rank * scount ,
270+ sdtype ,
271+ tmpbuf + ((ptrdiff_t ) (size - rank ) * rcount * rext ),
272+ (int32_t ) rank * rcount ,
273+ rdtype );
273274 if (err < 0 ) {
274275 line = __LINE__ ; err = -1 ; goto err_hndl ;
275276 }
@@ -280,19 +281,19 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
280281
281282 sendto = (rank + distance ) % size ;
282283 recvfrom = (rank - distance + size ) % size ;
283- k = 0 ;
284-
285- /* create indexed datatype */
286- for ( i = 1 ; i < size ; i ++ ) {
287- if (( i & distance ) == distance ) {
288- displs [ k ] = ( ptrdiff_t ) i * ( ptrdiff_t ) scount ;
289- blen [ k ] = scount ;
290- k ++ ;
284+
285+ new_ddt = ompi_datatype_create (( 1 + size / distance ) * ( 2 + rdtype -> super . desc . used ));
286+
287+ /* Create datatype describing data sent/received */
288+ for ( i = distance ; i < size ; i += 2 * distance ) {
289+ int nblocks = distance ;
290+ if ( i + distance >= size ) {
291+ nblocks = size - i ;
291292 }
293+ ompi_datatype_add (new_ddt , rdtype , rcount * nblocks ,
294+ i * rcount * rext , rext );
292295 }
293- /* Set indexes and displacements */
294- err = ompi_datatype_create_indexed (k , blen , displs , sdtype , & new_ddt );
295- if (err != MPI_SUCCESS ) { line = __LINE__ ; goto err_hndl ; }
296+
296297 /* Commit the new datatype */
297298 err = ompi_datatype_commit (& new_ddt );
298299 if (err != MPI_SUCCESS ) { line = __LINE__ ; goto err_hndl ; }
@@ -324,19 +325,16 @@ int ompi_coll_base_alltoall_intra_bruck(const void *sbuf, int scount,
324325 }
325326
326327 /* Step 4 - clean up */
327- if (tmpbuf != NULL ) free (tmpbuf_free );
328- if (displs != NULL ) free (displs );
329- if (blen != NULL ) free (blen );
328+ if (tmpbuf_free != NULL ) free (tmpbuf_free );
330329 return OMPI_SUCCESS ;
331330
332331 err_hndl :
333332 OPAL_OUTPUT ((ompi_coll_base_framework .framework_output ,
334333 "%s:%4d\tError occurred %d, rank %2d" , __FILE__ , line , err ,
335334 rank ));
336335 (void )line ; // silence compiler warning
337- if (tmpbuf != NULL ) free (tmpbuf_free );
336+ if (tmpbuf_free != NULL ) free (tmpbuf_free );
338337 if (displs != NULL ) free (displs );
339- if (blen != NULL ) free (blen );
340338 return err ;
341339}
342340
0 commit comments