@@ -130,12 +130,10 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
130130        # gather results 
131131        # TODO: _allgather is supposed to be private to DistributedArray 
132132        # but so far, we do not take base_comm_nccl as an argument to Op. 
133-         # For consistency, y._allgather has to be call  here. 
134-         # we can do if  else for  x.base_comm_nccl, but that means 
133+         # For consistency, y._allgather has to be called  here. 
134+         # Alternatively,  we can also  do if- else checking  x.base_comm_nccl, but that means 
135135        # we have to call function from _nccl.py 
136-         # y[:] = np.vstack(y._allgather(y1)).ravel() 
137-         recv  =  y ._allgather (y1 )
138-         y [:] =  recv .ravel ()
136+         y [:] =  ncp .vstack (y ._allgather (y1 )).ravel ()
139137        return  y 
140138
141139    def  _rmatvec (self , x : NDArray ) ->  NDArray :
@@ -172,11 +170,15 @@ def _rmatvec(self, x: NDArray) -> NDArray:
172170                    y1 [isl ] =  ncp .dot (x [isl ].T .conj (), self .G [isl ]).T .conj ()
173171
174172        # gather results 
175-         recv  =  y ._allgather (y1 ) 
176-         if  self .usematmul :
177-             # unrolling like DistributedArray asarray() 
173+         recv  =  y ._allgather (y1 )
174+         # TODO: current of _allgather will call non-buffered MPI-AllGather (sub-optimal for CuPy+MPI) 
175+         # which returns a list (not flatten) and does not require unrolling 
176+         if  self .usematmul  and  isinstance (recv , ncp .ndarray ) :
177+             # unrolling 
178178            chunk_size  =  self .ny  *  self .nz 
179-             recv  =  ncp .vstack ([recv [i * chunk_size : (i + 1 )* chunk_size ].reshape (self .nz , self .ny ).T  for  i  in  range ((len (recv )+ chunk_size - 1 )// chunk_size )])
180- 
179+             num_partition  =  (len (recv )+ chunk_size - 1 )// chunk_size 
180+             recv  =  ncp .vstack ([recv [i * chunk_size : (i + 1 )* chunk_size ].reshape (self .nz , self .ny ).T  for  i  in  range (num_partition )])
181+         else :
182+             recv  =  ncp .vstack (recv )
181183        y [:] =  recv .ravel ()
182184        return  y 
0 commit comments