@@ -136,13 +136,32 @@ def __init__(
136136 self ._col_comm = base_comm .Split (color = self ._col_id , key = self ._row_id )
137137
138138 self .A = A .astype (np .dtype (dtype ))
139- if saveAt :
140- self .At = A .T .conj ()
141139
142140 self .N = self ._col_comm .allreduce (A .shape [0 ])
143141 self .K = self ._row_comm .allreduce (A .shape [1 ])
144142 self .M = M
145143
144+ self ._N_padded = math .ceil (self .N / self ._P_prime ) * self ._P_prime
145+ self ._K_padded = math .ceil (self .K / self ._P_prime ) * self ._P_prime
146+ self ._M_padded = math .ceil (self .M / self ._P_prime ) * self ._P_prime
147+
148+ bn = self ._N_padded // self ._P_prime
149+ bk = self ._K_padded // self ._P_prime
150+ bm = self ._M_padded // self ._P_prime
151+
152+ pr = (bn - A .shape [0 ]) if self ._row_id == self ._P_prime - 1 else 0
153+ pc = (bk - A .shape [1 ]) if self ._col_id == self ._P_prime - 1 else 0
154+
155+ if pr < 0 or pc < 0 :
156+ raise Exception (f"Improper distribution of A expected local shape "
157+ f"( ≤ { bn } , ≤ { bk } ) but got ({ A .shape [0 ]} ,{ A .shape [1 ]} )" )
158+
159+ if pr > 0 or pc > 0 :
160+ self .A = np .pad (self .A , [(0 , pr ), (0 , pc )], mode = 'constant' )
161+
162+ if saveAt :
163+ self .At = self .A .T .conj ()
164+
146165 self .dims = (self .K , self .M )
147166 self .dimsd = (self .N , self .M )
148167 shape = (int (np .prod (self .dimsd )), int (np .prod (self .dims )))
@@ -218,65 +237,185 @@ def block_distribute(array, proc_i, proc_j, comm):
218237 i0 , j0 = proc_i * br , proc_j * bc
219238 i1 , j1 = min (i0 + br , orig_r ), min (j0 + bc , orig_c )
220239
221- block = array [i0 :i1 , j0 :j1 ]
240+ i_end = None if proc_i == p_prime - 1 else i1
241+ j_end = None if proc_j == p_prime - 1 else j1
242+ block = array [i0 :i_end , j0 :j_end ]
243+
222244 pr = (new_r - orig_r ) if proc_i == p_prime - 1 else 0
223245 pc = (new_c - orig_c ) if proc_j == p_prime - 1 else 0
224- if pr or pc :
225- block = np .pad (block , [(0 , pr ), (0 , pc )], mode = 'constant' )
226-
246+ #comment the padding to get the block as unpadded
247+ # if pr or pc: block = np.pad(block, [(0, pr), (0, pc)], mode='constant')
227248 return block , (new_r , new_c )
228249
@staticmethod
def block_gather(x, new_shape, orig_shape, comm):
    """Gather the per-rank local blocks of ``x`` into the full global matrix.

    Inverse of ``block_distribute``: every rank contributes its (unpadded)
    local block via ``allgather`` and the blocks are reassembled on the
    p' x p' process grid.

    Parameters
    ----------
    x : DistributedArray
        Distributed array whose ``local_array`` holds this rank's block.
    new_shape : tuple[int, int]
        Padded global shape ``(nr, nc)``; each dimension is a multiple of
        ``p'`` (the padded sizes computed in ``__init__``).
    orig_shape : tuple[int, int]
        Original (unpadded) global shape ``(orr, orc)``.
    comm : MPI communicator
        Communicator spanning all ``p' * p'`` ranks of the grid.

    Returns
    -------
    ncp.ndarray
        The assembled global matrix of shape ``orig_shape``.
    """
    ncp = get_module(x.engine)
    p_prime = math.isqrt(comm.Get_size())
    all_blks = comm.allgather(x.local_array)

    nr, nc = new_shape
    orr, orc = orig_shape

    # Block geometry must mirror block_distribute: equal blocks of the
    # padded size, with the LAST grid row/col holding a trimmed block
    # (blocks are no longer zero-padded on the sender side).
    br = nr // p_prime
    bc = nc // p_prime

    # Assemble directly at the original (unpadded) size.
    C = ncp.zeros((orr, orc), dtype=all_blks[0].dtype)

    for rank in range(p_prime * p_prime):
        # Linear rank -> 2D grid position (row-major, as in __init__).
        proc_row, proc_col = divmod(rank, p_prime)

        # This rank's block extent in the global matrix, clipped to the
        # original shape for the last row/col of the grid.
        i0, j0 = proc_row * br, proc_col * bc
        i1, j1 = min(i0 + br, orr), min(j0 + bc, orc)
        block_rows, block_cols = i1 - i0, j1 - j0

        block = all_blks[rank]
        if block.ndim == 1:
            # Flattened blocks carry exactly their unpadded extent.
            block = block.reshape(block_rows, block_cols)
        # Slice defensively so a block that arrives padded is also accepted.
        C[i0:i1, j0:j1] = block[:block_rows, :block_cols]
    return C
239290
def _matvec(self, x: DistributedArray) -> DistributedArray:
    """Distributed matrix-vector product ``y = A @ x`` via SUMMA.

    ``x`` holds this rank's (unpadded) block of the K x M operand; the
    result ``y`` is returned as a DistributedArray holding this rank's
    (unpadded) block of the N x M product.

    Parameters
    ----------
    x : DistributedArray
        Input with ``Partition.SCATTER``; its ``local_array`` flattens a
        ``(local_k, local_m)`` block.

    Returns
    -------
    DistributedArray
        Output with one ``(local_n, local_m)`` block per rank, flattened.

    Raises
    ------
    ValueError
        If ``x`` does not use ``Partition.SCATTER``.
    """
    ncp = get_module(x.engine)
    if x.partition != Partition.SCATTER:
        raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")

    # Padded block sizes on the p' x p' grid.
    bn = self._N_padded // self._P_prime  # block size in N dimension
    bk = self._K_padded // self._P_prime  # block size in K dimension
    bm = self._M_padded // self._P_prime  # block size in M dimension

    def _extent(idx: int, block: int, total: int) -> int:
        # Actual (unpadded) extent of a block: the last grid row/col
        # holds whatever remains of the original dimension.
        return block if idx != self._P_prime - 1 else total - (self._P_prime - 1) * block

    local_n = _extent(self._row_id, bn, self.N)
    local_m = _extent(self._col_id, bm, self.M)
    local_k = _extent(self._row_id, bk, self.K)

    # Unpadded output block sizes for every rank (row-major rank -> grid).
    local_shapes = []
    for rank in range(self.size):
        row_id, col_id = divmod(rank, self._P_prime)
        local_shapes.append(_extent(row_id, bn, self.N) * _extent(col_id, bm, self.M))

    y = DistributedArray(global_shape=(self.N * self.M),
                         mask=x.mask,
                         local_shapes=local_shapes,
                         partition=Partition.SCATTER,
                         dtype=self.dtype,
                         base_comm=self.base_comm
                         )

    # Reshape the incoming flat block and zero-pad it up to the full
    # padded block size so every rank broadcasts equal-sized buffers.
    x_block = x.local_array.reshape((local_k, local_m))
    pad_k = bk - local_k
    pad_m = bm - local_m
    if pad_k > 0 or pad_m > 0:
        # ncp.pad (not np.pad): x_block lives on x's engine (numpy/cupy).
        x_block = ncp.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant')

    # dtype must follow the operator: a float64 default would make the
    # in-place accumulation fail for complex operands.
    Y_local = ncp.zeros((self.A.shape[0], bm), dtype=self.dtype)

    # SUMMA: at step k, grid column k broadcasts its A block along rows
    # and grid row k broadcasts its X block along columns.
    for k in range(self._P_prime):
        Atemp = self.A.copy() if self._col_id == k else np.empty_like(self.A)
        Xtemp = x_block.copy() if self._row_id == k else np.empty_like(x_block)
        self._row_comm.Bcast(Atemp, root=k)
        self._col_comm.Bcast(Xtemp, root=k)
        Y_local += ncp.dot(Atemp, Xtemp)

    # Strip the padding before scattering the result.
    Y_local_unpadded = Y_local[:local_n, :local_m]
    y[:] = Y_local_unpadded.flatten()
    return y
262359
263360 def _rmatvec (self , x : DistributedArray ) -> DistributedArray :
264361 ncp = get_module (x .engine )
265362 if x .partition != Partition .SCATTER :
266363 raise ValueError (f"x should have partition={ Partition .SCATTER } . Got { x .partition } instead." )
267364
268- local_shape = ((self .K * self .M ) // self .size )
365+ # Calculate local shapes for block distribution
366+ bk = self ._K_padded // self ._P_prime # block size in K dimension
367+ bm = self ._M_padded // self ._P_prime # block size in M dimension
368+
369+ # Calculate actual local shape for this process (considering original dimensions)
370+ local_k = bk
371+ local_m = bm
372+
373+ # Adjust for edge/corner processes that might have smaller blocks
374+ if self ._row_id == self ._P_prime - 1 :
375+ local_k = self .K - (self ._P_prime - 1 ) * bk
376+ if self ._col_id == self ._P_prime - 1 :
377+ local_m = self .M - (self ._P_prime - 1 ) * bm
378+
379+ local_shape = local_k * local_m
380+
381+ # Create local_shapes array for all processes
382+ local_shapes = []
383+ for rank in range (self .size ):
384+ row_id , col_id = divmod (rank , self ._P_prime )
385+ proc_k = bk if row_id != self ._P_prime - 1 else self .K - (self ._P_prime - 1 ) * bk
386+ proc_m = bm if col_id != self ._P_prime - 1 else self .M - (self ._P_prime - 1 ) * bm
387+ local_shapes .append (proc_k * proc_m )
388+
269389 y = DistributedArray (
270390 global_shape = (self .K * self .M ),
271391 mask = x .mask ,
272- local_shapes = [ local_shape ] * self . size ,
392+ local_shapes = local_shapes ,
273393 partition = Partition .SCATTER ,
274394 dtype = self .dtype ,
275395 base_comm = self .base_comm
276396 )
277- x_reshaped = x .local_array .reshape ((self .A .shape [0 ], - 1 ))
397+
398+ # Calculate expected padded dimensions for x
399+ bn = self ._N_padded // self ._P_prime # block size in N dimension
400+
401+ # The input x corresponds to blocks from the result (N x M)
402+ # This process should receive a block of size (local_n x local_m)
403+ local_n = bn
404+ if self ._row_id == self ._P_prime - 1 :
405+ local_n = self .N - (self ._P_prime - 1 ) * bn
406+
407+ # Reshape x.local_array to its 2D block form
408+ x_block = x .local_array .reshape ((local_n , local_m ))
409+
410+ # Pad the block to the full padded size if necessary
411+ pad_n = bn - local_n
412+ pad_m = bm - local_m
413+
414+ if pad_n > 0 or pad_m > 0 :
415+ x_block = np .pad (x_block , [(0 , pad_n ), (0 , pad_m )], mode = 'constant' )
416+
278417 A_local = self .At if hasattr (self , "At" ) else self .A .T .conj ()
279- Y_local = np .zeros ((self .A .shape [1 ], x_reshaped . shape [ 1 ] ))
418+ Y_local = np .zeros ((self .A .shape [1 ], bm ))
280419
281420 for k in range (self ._P_prime ):
282421 requests = []
@@ -289,10 +428,12 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
289428 for moving_col in range (self ._P_prime ):
290429 destA = fixed_col * self ._P_prime + moving_col
291430 tagA = (100 + k ) * 1000 + destA
292- requests .append (self .base_comm .Isend (A_local , dest = destA ,tag = tagA ))
293- Xtemp = x_reshaped .copy () if self ._row_id == k else np .empty_like (x_reshaped )
431+ requests .append (self .base_comm .Isend (A_local , dest = destA , tag = tagA ))
432+ Xtemp = x_block .copy () if self ._row_id == k else np .empty_like (x_block )
294433 requests .append (self ._col_comm .Ibcast (Xtemp , root = k ))
295434 MPI .Request .Waitall (requests )
296435 Y_local += ncp .dot (ATtemp , Xtemp )
297- y [:] = Y_local .flatten ()
436+
437+ Y_local_unpadded = Y_local [:local_k , :local_m ]
438+ y [:] = Y_local_unpadded .flatten ()
298439 return y
0 commit comments