@@ -103,10 +103,6 @@ struct ompi_comm_allreduce_context_t {
103103    ompi_comm_cid_context_t  * cid_context ;
104104    int  * tmpbuf ;
105105
106-     /* for intercomm allreduce */ 
107-     int  * rcounts ;
108-     int  * rdisps ;
109- 
110106    /* for group allreduce */ 
111107    int  peers_comm [3 ];
112108};
@@ -121,8 +117,6 @@ static void ompi_comm_allreduce_context_construct (ompi_comm_allreduce_context_t
121117static  void  ompi_comm_allreduce_context_destruct  (ompi_comm_allreduce_context_t  * context )
122118{
123119    free  (context -> tmpbuf );
124-     free  (context -> rcounts );
125-     free  (context -> rdisps );
126120}
127121
128122OBJ_CLASS_INSTANCE  (ompi_comm_allreduce_context_t , opal_object_t ,
@@ -602,7 +596,7 @@ static int ompi_comm_allreduce_intra_nb (int *inbuf, int *outbuf, int count, str
602596/* Non-blocking version of ompi_comm_allreduce_inter */ 
603597static  int  ompi_comm_allreduce_inter_leader_exchange  (ompi_comm_request_t  * request );
604598static  int  ompi_comm_allreduce_inter_leader_reduce  (ompi_comm_request_t  * request );
605- static  int  ompi_comm_allreduce_inter_allgather  (ompi_comm_request_t  * request );
599+ static  int  ompi_comm_allreduce_inter_bcast  (ompi_comm_request_t  * request );
606600
607601static  int  ompi_comm_allreduce_inter_nb  (int  * inbuf , int  * outbuf ,
608602                                         int  count , struct  ompi_op_t  * op ,
@@ -636,18 +630,19 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
636630    rsize       =  ompi_comm_remote_size  (intercomm );
637631    local_rank  =  ompi_comm_rank  (intercomm );
638632
639-     context -> tmpbuf    =  ( int   * )  calloc  ( count ,  sizeof ( int )); 
640-     context -> rdisps   =  (int  * ) calloc  (rsize , sizeof (int ));
641-     context -> rcounts   =  ( int   * )  calloc  ( rsize ,  sizeof ( int )); 
642-     if  ( OPAL_UNLIKELY  ( NULL   ==   context -> tmpbuf   ||   NULL   ==   context -> rdisps   ||   NULL   ==   context -> rcounts )) { 
643-         ompi_comm_request_return  ( request ) ;
644-         return   OMPI_ERR_OUT_OF_RESOURCE ; 
633+     if  ( 0   ==   local_rank ) { 
634+          context -> tmpbuf   =  (int  * ) calloc  (count , sizeof (int ));
635+          if  ( OPAL_UNLIKELY  ( NULL   ==   context -> tmpbuf )) { 
636+              ompi_comm_request_return  ( request ); 
637+              return   OMPI_ERR_OUT_OF_RESOURCE ;
638+         } 
645639    }
646640
647641    /* Execute the inter-allreduce: the result from the local will be in the buffer of the remote group 
648642     * and vice versa. */ 
649-     rc  =  intercomm -> c_coll .coll_iallreduce  (inbuf , context -> tmpbuf , count , MPI_INT , op , intercomm ,
650-                                             & subreq , intercomm -> c_coll .coll_iallreduce_module );
643+     rc  =  intercomm -> c_local_comm -> c_coll .coll_ireduce  (inbuf , context -> tmpbuf , count , MPI_INT , op , 0 ,
644+                                                        intercomm -> c_local_comm , & subreq ,
645+                                                        intercomm -> c_local_comm -> c_coll .coll_ireduce_module );
651646    if  (OPAL_UNLIKELY (OMPI_SUCCESS  !=  rc )) {
652647        ompi_comm_request_return  (request );
653648        return  rc ;
@@ -656,7 +651,7 @@ static int ompi_comm_allreduce_inter_nb (int *inbuf, int *outbuf,
656651    if  (0  ==  local_rank ) {
657652        ompi_comm_request_schedule_append  (request , ompi_comm_allreduce_inter_leader_exchange , & subreq , 1 );
658653    } else  {
659-         ompi_comm_request_schedule_append  (request , ompi_comm_allreduce_inter_allgather , & subreq , 1 );
654+         ompi_comm_request_schedule_append  (request , ompi_comm_allreduce_inter_bcast , & subreq , 1 );
660655    }
661656
662657    ompi_comm_request_start  (request );
@@ -696,33 +691,20 @@ static int ompi_comm_allreduce_inter_leader_reduce (ompi_comm_request_t *request
696691
697692    ompi_op_reduce  (context -> op , context -> tmpbuf , context -> outbuf , context -> count , MPI_INT );
698693
699-     return  ompi_comm_allreduce_inter_allgather  (request );
694+     return  ompi_comm_allreduce_inter_bcast  (request );
700695}
701696
702697
703- static  int  ompi_comm_allreduce_inter_allgather  (ompi_comm_request_t  * request )
698+ static  int  ompi_comm_allreduce_inter_bcast  (ompi_comm_request_t  * request )
704699{
705700    ompi_comm_allreduce_context_t  * context  =  (ompi_comm_allreduce_context_t  * ) request -> context ;
706-     ompi_communicator_t  * intercomm  =  context -> cid_context -> comm ;
701+     ompi_communicator_t  * comm  =  context -> cid_context -> comm -> c_local_comm ;
707702    ompi_request_t  * subreq ;
708703    int  scount  =  0 , rc ;
709704
710-     /* distribute the overall result to all processes in the other group. 
711-        Instead of using bcast, we are using here allgatherv, to avoid the 
712-        possible deadlock. Else, we need an algorithm to determine, 
713-        which group sends first in the inter-bcast and which receives 
714-        the result first. 
715-     */ 
716- 
717-     if  (0  !=  ompi_comm_rank  (intercomm )) {
718-         context -> rcounts [0 ] =  context -> count ;
719-     } else  {
720-         scount  =  context -> count ;
721-     }
722- 
723-     rc  =  intercomm -> c_coll .coll_iallgatherv  (context -> outbuf , scount , MPI_INT , context -> outbuf ,
724-                                              context -> rcounts , context -> rdisps , MPI_INT , intercomm ,
725-                                              & subreq , intercomm -> c_coll .coll_iallgatherv_module );
705+     /* both roots have the same result. broadcast to the local group */ 
706+     rc  =  comm -> c_coll .coll_ibcast  (context -> outbuf , context -> count , MPI_INT , 0 , comm ,
707+                                    & subreq , comm -> c_coll .coll_ibcast_module );
726708    if  (OMPI_SUCCESS  !=  rc ) {
727709        return  rc ;
728710    }
0 commit comments