@@ -246,22 +246,22 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
246246 ncp = get_module (x .engine )
247247 if x .partition != Partition .SCATTER :
248248 raise ValueError (f"x should have partition={ Partition .SCATTER } Got { x .partition } instead..." )
249- local_shape = (self .N // self . _P_prime ) * ( self .M * self . _P_prime // self .size )
250- y = DistributedArray (global_shape = (( self .N // self . _P_prime ) * self .M * self . _P_prime ),
249+ local_shape = (( self .N * self .M ) // self .size )
250+ y = DistributedArray (global_shape = (self .N * self .M ),
251251 mask = x .mask ,
252- local_shapes = [ local_shape for _ in range ( self .size )] ,
252+ local_shapes = [local_shape ] * self .size ,
253253 partition = Partition .SCATTER ,
254254 dtype = self .dtype )
255255
256256 x = x .local_array .reshape ((self .A .shape [1 ], - 1 ))
257- c_local = np .zeros ((self .A .shape [0 ], x .shape [1 ]))
257+ Y_local = np .zeros ((self .A .shape [0 ], x .shape [1 ]))
258258 for k in range (self ._P_prime ):
259259 Atemp = self .A .copy () if self ._col_id == k else np .empty_like (self .A )
260260 Xtemp = x .copy () if self ._row_id == k else np .empty_like (x )
261261 self ._row_comm .Bcast (Atemp , root = k )
262262 self ._col_comm .Bcast (Xtemp , root = k )
263- c_local += ncp .dot (Atemp , Xtemp )
264- y [:] = c_local .flatten ()
263+ Y_local += ncp .dot (Atemp , Xtemp )
264+ y [:] = Y_local .flatten ()
265265 return y
266266
267267
@@ -270,38 +270,33 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
270270 if x .partition != Partition .SCATTER :
271271 raise ValueError (f"x should have partition={ Partition .SCATTER } . Got { x .partition } instead." )
272272
273- local_shape = (self .K // self . _P_prime ) * ( self .M * self . _P_prime // self .size )
273+ local_shape = (( self .K * self .M ) // self .size )
274274 y = DistributedArray (
275- global_shape = (( self .K // self . _P_prime ) * self .M * self . _P_prime ),
275+ global_shape = (self .K * self .M ),
276276 mask = x .mask ,
277- local_shapes = [local_shape for _ in range ( self .size )] ,
277+ local_shapes = [local_shape ] * self .size ,
278278 partition = Partition .SCATTER ,
279279 dtype = self .dtype ,
280280 )
281281 x_reshaped = x .local_array .reshape ((self .A .shape [0 ], - 1 ))
282282 A_local = self .At if hasattr (self , "At" ) else self .A .T .conj ()
283- c_local = np .zeros ((self .A .shape [1 ], x_reshaped .shape [1 ]))
284- P = self ._P_prime
283+ Y_local = np .zeros ((self .A .shape [1 ], x_reshaped .shape [1 ]))
285284
286- for k in range (P ):
287- temps = {}
285+ for k in range (self ._P_prime ):
288286 requests = []
289- for buf , owner , base , name in (
290- (A_local , self ._row_id , 100 , 'A' ),
291- (x_reshaped , self ._col_id , 200 , 'B' ),
292- ):
293- tmp = np .empty_like (buf )
294- temps [name ] = tmp
295- src , tag = k * P + owner , (base + k ) * 1000 + self .rank
296- requests .append (self .base_comm .Irecv (tmp , source = src , tag = tag ))
297-
298- if self .rank // P == k :
299- fixed = self .rank % P
300- for moving in range (P ):
301- dest = (fixed * P + moving ) if name == 'A' else moving * P + fixed
302- tag = (base + k ) * 1000 + dest
303- requests .append (self .base_comm .Isend (buf , dest = dest , tag = tag ))
287+ ATtemp = np .empty_like (A_local )
288+ srcA = k * self ._P_prime + self ._row_id
289+ tagA = (100 + k ) * 1000 + self .rank
290+ requests .append (self .base_comm .Irecv (ATtemp , source = srcA , tag = tagA ))
291+ if self ._row_id == k :
292+ fixed_col = self ._col_id
293+ for moving_col in range (self ._P_prime ):
294+ destA = fixed_col * self ._P_prime + moving_col
295+ tagA = (100 + k ) * 1000 + destA
296+ requests .append (self .base_comm .Isend (A_local , dest = destA ,tag = tagA ))
297+ Xtemp = x_reshaped .copy () if self ._row_id == k else np .empty_like (x_reshaped )
298+ requests .append (self ._col_comm .Ibcast (Xtemp , root = k ))
304299 MPI .Request .Waitall (requests )
305- c_local += ncp .dot (temps [ 'A' ], temps [ 'B' ] )
306- y [:] = c_local .flatten ()
300+ Y_local += ncp .dot (ATtemp , Xtemp )
301+ y [:] = Y_local .flatten ()
307302 return y
0 commit comments