@@ -130,12 +130,10 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
130130 # gather results
131131 # TODO: _allgather is supposed to be private to DistributedArray
132132 # but so far, we do not take base_comm_nccl as an argument to Op.
133- # For consistency, y._allgather has to be call here.
134- # we can do if else for x.base_comm_nccl, but that means
133+ # For consistency, y._allgather has to be called here.
134+ # Alternatively, we could do an if-else check on x.base_comm_nccl, but that means
135135 # we have to call function from _nccl.py
136- # y[:] = np.vstack(y._allgather(y1)).ravel()
137- recv = y ._allgather (y1 )
138- y [:] = recv .ravel ()
136+ y [:] = ncp .vstack (y ._allgather (y1 )).ravel ()
139137 return y
140138
141139 def _rmatvec (self , x : NDArray ) -> NDArray :
@@ -172,11 +170,15 @@ def _rmatvec(self, x: NDArray) -> NDArray:
172170 y1 [isl ] = ncp .dot (x [isl ].T .conj (), self .G [isl ]).T .conj ()
173171
174172 # gather results
175- recv = y ._allgather (y1 )
176- if self .usematmul :
177- # unrolling like DistributedArray asarray()
173+ recv = y ._allgather (y1 )
174+ # TODO: the current implementation of _allgather calls non-buffered MPI allgather (sub-optimal for CuPy+MPI),
175+ # which returns a list (not flattened) and does not require unrolling
176+ if self .usematmul and isinstance (recv , ncp .ndarray ) :
177+ # unrolling
178178 chunk_size = self .ny * self .nz
179- recv = ncp .vstack ([recv [i * chunk_size : (i + 1 )* chunk_size ].reshape (self .nz , self .ny ).T for i in range ((len (recv )+ chunk_size - 1 )// chunk_size )])
180-
179+ num_partition = (len (recv )+ chunk_size - 1 )// chunk_size
180+ recv = ncp .vstack ([recv [i * chunk_size : (i + 1 )* chunk_size ].reshape (self .nz , self .ny ).T for i in range (num_partition )])
181+ else :
182+ recv = ncp .vstack (recv )
181183 y [:] = recv .ravel ()
182184 return y
0 commit comments