88 * Copyright (c) 2013-2015 Los Alamos National Security, LLC. All rights
99 * reserved.
1010 * Copyright (c) 2014 NVIDIA Corporation. All rights reserved.
11- * Copyright (c) 2014-2015 Research Organization for Information Science
11+ * Copyright (c) 2014-2016 Research Organization for Information Science
1212 * and Technology (RIST). All rights reserved.
1313 *
1414 * Author(s): Torsten Hoefler <[email protected] > @@ -25,6 +25,8 @@ static inline int a2a_sched_pairwise(int rank, int p, MPI_Aint sndext, MPI_Aint
2525static inline int a2a_sched_diss (int rank , int p , MPI_Aint sndext , MPI_Aint rcvext , NBC_Schedule * schedule ,
2626 const void * sendbuf , int sendcount , MPI_Datatype sendtype , void * recvbuf ,
2727 int recvcount , MPI_Datatype recvtype , MPI_Comm comm , NBC_Handle * handle );
28+ static inline int a2a_sched_inplace (int rank , int p , NBC_Schedule * schedule , void * buf , int count ,
29+ MPI_Datatype type , MPI_Aint ext , ptrdiff_t gap , MPI_Comm comm );
2830
2931#ifdef NBC_CACHE_SCHEDULE
3032/* tree comparison function for schedule cache */
@@ -59,9 +61,10 @@ int ompi_coll_libnbc_ialltoall(const void* sendbuf, int sendcount, MPI_Datatype
5961 NBC_Alltoall_args * args , * found , search ;
6062#endif
6163 char * rbuf , * sbuf , inplace ;
62- enum {NBC_A2A_LINEAR , NBC_A2A_PAIRWISE , NBC_A2A_DISS } alg ;
64+ enum {NBC_A2A_LINEAR , NBC_A2A_PAIRWISE , NBC_A2A_DISS , NBC_A2A_INPLACE } alg ;
6365 NBC_Handle * handle ;
6466 ompi_coll_libnbc_module_t * libnbc_module = (ompi_coll_libnbc_module_t * ) module ;
67+ ptrdiff_t span , gap ;
6568
6669 NBC_IN_PLACE (sendbuf , recvbuf , inplace );
6770
@@ -89,7 +92,9 @@ int ompi_coll_libnbc_ialltoall(const void* sendbuf, int sendcount, MPI_Datatype
8992 /* algorithm selection */
9093 a2asize = sndsize * sendcount * p ;
9194 /* this number is optimized for TCP on odin.cs.indiana.edu */
92- if ((p <= 8 ) && ((a2asize < 1 <<17 ) || (sndsize * sendcount < 1 <<12 ))) {
95+ if (inplace ) {
96+ alg = NBC_A2A_INPLACE ;
97+ } else if ((p <= 8 ) && ((a2asize < 1 <<17 ) || (sndsize * sendcount < 1 <<12 ))) {
9398 /* just send as fast as we can if we have less than 8 peers, if the
9499 * total communicated size is smaller than 1<<17 *and* if we don't
95100 * have eager messages (msgsize < 1<<13) */
@@ -116,7 +121,14 @@ int ompi_coll_libnbc_ialltoall(const void* sendbuf, int sendcount, MPI_Datatype
116121 }
117122
118123 /* allocate temp buffer if we need one */
119- if (alg == NBC_A2A_DISS ) {
124+ if (alg == NBC_A2A_INPLACE ) {
125+ span = opal_datatype_span (& recvtype -> super , recvcount , & gap );
126+ handle -> tmpbuf = malloc (span );
127+ if (OPAL_UNLIKELY (NULL == handle -> tmpbuf )) {
128+ NBC_Return_handle (handle );
129+ return OMPI_ERR_OUT_OF_RESOURCE ;
130+ }
131+ } else if (alg == NBC_A2A_DISS ) {
120132 /* only A2A_DISS needs buffers */
121133 if (NBC_Type_intrinsic (sendtype )) {
122134 datasize = sndext * sendcount ;
@@ -200,6 +212,9 @@ int ompi_coll_libnbc_ialltoall(const void* sendbuf, int sendcount, MPI_Datatype
200212 handle -> schedule = schedule ;
201213
202214 switch (alg ) {
215+ case NBC_A2A_INPLACE :
216+ res = a2a_sched_inplace (rank , p , schedule , recvbuf , recvcount , recvtype , rcvext , gap , comm );
217+ break ;
203218 case NBC_A2A_LINEAR :
204219 res = a2a_sched_linear (rank , p , sndext , rcvext , schedule , sendbuf , sendcount , sendtype , recvbuf , recvcount , recvtype , comm );
205220 break ;
@@ -359,17 +374,10 @@ static inline int a2a_sched_pairwise(int rank, int p, MPI_Aint sndext, MPI_Aint
359374 }
360375
361376 char * sbuf = (char * ) sendbuf + sndpeer * sendcount * sndext ;
362- res = NBC_Sched_send (sbuf , false, sendcount , sendtype , sndpeer , schedule , false );
377+ res = NBC_Sched_send (sbuf , false, sendcount , sendtype , sndpeer , schedule , true );
363378 if (OPAL_UNLIKELY (OMPI_SUCCESS != res )) {
364379 return res ;
365380 }
366-
367- if (r < p ) {
368- res = NBC_Sched_barrier (schedule );
369- if (OPAL_UNLIKELY (OMPI_SUCCESS != res )) {
370- return res ;
371- }
372- }
373381 }
374382
375383 return OMPI_SUCCESS ;
@@ -497,3 +505,59 @@ static inline int a2a_sched_diss(int rank, int p, MPI_Aint sndext, MPI_Aint rcve
497505 return OMPI_SUCCESS ;
498506}
499507
508+ static inline int a2a_sched_inplace (int rank , int p , NBC_Schedule * schedule , void * buf , int count ,
509+ MPI_Datatype type , MPI_Aint ext , ptrdiff_t gap , MPI_Comm comm ) {
510+ int res ;
511+
512+ for (int i = 1 ; i < (p + 1 )/2 ; i ++ ) {
513+ int speer = (rank + i ) % p ;
514+ int rpeer = (rank + p - i ) % p ;
515+ char * sbuf = (char * ) buf + speer * count * ext ;
516+ char * rbuf = (char * ) buf + rpeer * count * ext ;
517+
518+ res = NBC_Sched_copy (rbuf , false, count , type ,
519+ (void * )(- gap ), true, count , type ,
520+ schedule , true);
521+ if (OPAL_UNLIKELY (OMPI_SUCCESS != res )) {
522+ return res ;
523+ }
524+ res = NBC_Sched_send (sbuf , false , count , type , speer , schedule , false);
525+ if (OPAL_UNLIKELY (OMPI_SUCCESS != res )) {
526+ return res ;
527+ }
528+ res = NBC_Sched_recv (rbuf , false , count , type , rpeer , schedule , true);
529+ if (OPAL_UNLIKELY (OMPI_SUCCESS != res )) {
530+ return res ;
531+ }
532+
533+ res = NBC_Sched_send ((void * )(- gap ), true, count , type , rpeer , schedule , false);
534+ if (OPAL_UNLIKELY (OMPI_SUCCESS != res )) {
535+ return res ;
536+ }
537+ res = NBC_Sched_recv (sbuf , false, count , type , speer , schedule , true);
538+ if (OPAL_UNLIKELY (OMPI_SUCCESS != res )) {
539+ return res ;
540+ }
541+ }
542+ if (0 == (p %2 )) {
543+ int peer = (rank + p /2 ) % p ;
544+
545+ char * tbuf = (char * ) buf + peer * count * ext ;
546+ res = NBC_Sched_copy (tbuf , false, count , type ,
547+ (void * )(- gap ), true, count , type ,
548+ schedule , true);
549+ if (OPAL_UNLIKELY (OMPI_SUCCESS != res )) {
550+ return res ;
551+ }
552+ res = NBC_Sched_send ((void * )(- gap ), true , count , type , peer , schedule , false);
553+ if (OPAL_UNLIKELY (OMPI_SUCCESS != res )) {
554+ return res ;
555+ }
556+ res = NBC_Sched_recv (tbuf , false , count , type , peer , schedule , true);
557+ if (OPAL_UNLIKELY (OMPI_SUCCESS != res )) {
558+ return res ;
559+ }
560+ }
561+
562+ return OMPI_SUCCESS ;
563+ }
0 commit comments