Skip to content

Commit dc00226

Browse files
committed
added handling for padding
1 parent 4a94ac6 commit dc00226

File tree

2 files changed

+181
-38
lines changed

2 files changed

+181
-38
lines changed

examples/matrixmul.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,28 +24,30 @@
2424
B_data = np.arange(int(B_shape[0] * B_shape[1])).reshape(B_shape)
2525

2626
i, j = divmod(rank, p_prime)
27-
2827
A_local, (N_new, K_new) = MPIMatrixMult.block_distribute(A_data, i, j,comm)
2928
B_local, (K_new, M_new) = MPIMatrixMult.block_distribute(B_data, i, j,comm)
3029

31-
B_dist = pylops_mpi.DistributedArray(global_shape=(K_new*M_new),
30+
B_dist = pylops_mpi.DistributedArray(global_shape=(K * M),
3231
local_shapes=comm.allgather(B_local.shape[0] * B_local.shape[1]),
3332
base_comm=comm,
3433
partition=pylops_mpi.Partition.SCATTER)
3534
B_dist.local_array[:] = B_local.flatten()
3635

37-
Aop = MPIMatrixMult(A_local, M_new, base_comm=comm)
36+
Aop = MPIMatrixMult(A_local, M, base_comm=comm)
3837
C_dist = Aop @ B_dist
3938
Z_dist = Aop.H @ C_dist
4039

41-
C = MPIMatrixMult.block_gather(C_dist, (N_new,M_new), (N,M), comm)
42-
Z = MPIMatrixMult.block_gather(Z_dist, (K_new,M_new), (K,M), comm)
40+
C = MPIMatrixMult.block_gather(C_dist, (N,M), (N,M), comm)
41+
Z = MPIMatrixMult.block_gather(Z_dist, (K,M), (K,M), comm)
4342
if rank == 0 :
44-
print("expected:\n", np.allclose((A_data.T.dot((A_data @ B_data).conj())).conj(), Z.astype(np.int32)))
45-
# print("expected:\n", (A_data.T.dot((A_data @ B_data).conj())).conj())
46-
# print("calculated:\n",Z.astype(np.int32))
47-
# print("calculated:\n", (A_data.T.dot((A_data @ B_data).conj())).conj() == Z.astype(np.int32))
48-
49-
# print("expected:\n",np.allclose(A_data @ B_data, C))
50-
# print("expected:\n", A_data @ B_data)
51-
# print("calculated:\n",C)
43+
C_correct = np.allclose(A_data @ B_data, C)
44+
print("C expected: ", C_correct)
45+
if not C_correct:
46+
print("expected:\n", A_data @ B_data)
47+
print("calculated:\n",C)
48+
49+
Z_correct = np.allclose((A_data.T.dot((A_data @ B_data).conj())).conj(), Z.astype(np.int32))
50+
print("Z expected: ", Z_correct)
51+
if not Z_correct:
52+
print("expected:\n", (A_data.T.dot((A_data @ B_data).conj())).conj())
53+
print("calculated:\n", Z.astype(np.int32))

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 166 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -136,13 +136,32 @@ def __init__(
136136
self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id)
137137

138138
self.A = A.astype(np.dtype(dtype))
139-
if saveAt:
140-
self.At = A.T.conj()
141139

142140
self.N = self._col_comm.allreduce(A.shape[0])
143141
self.K = self._row_comm.allreduce(A.shape[1])
144142
self.M = M
145143

144+
self._N_padded = math.ceil(self.N / self._P_prime) * self._P_prime
145+
self._K_padded = math.ceil(self.K / self._P_prime) * self._P_prime
146+
self._M_padded = math.ceil(self.M / self._P_prime) * self._P_prime
147+
148+
bn = self._N_padded // self._P_prime
149+
bk = self._K_padded // self._P_prime
150+
bm = self._M_padded // self._P_prime
151+
152+
pr = (bn - A.shape[0]) if self._row_id == self._P_prime - 1 else 0
153+
pc = (bk - A.shape[1]) if self._col_id == self._P_prime - 1 else 0
154+
155+
if pr < 0 or pc < 0:
156+
raise Exception(f"Improper distribution of A expected local shape "
157+
f"( ≤ {bn}, ≤ {bk}) but got ({A.shape[0]},{A.shape[1]})")
158+
159+
if pr > 0 or pc > 0:
160+
self.A = np.pad(self.A, [(0, pr), (0, pc)], mode='constant')
161+
162+
if saveAt:
163+
self.At = self.A.T.conj()
164+
146165
self.dims = (self.K, self.M)
147166
self.dimsd = (self.N, self.M)
148167
shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
@@ -218,65 +237,185 @@ def block_distribute(array, proc_i, proc_j, comm):
218237
i0, j0 = proc_i * br, proc_j * bc
219238
i1, j1 = min(i0 + br, orig_r), min(j0 + bc, orig_c)
220239

221-
block = array[i0:i1, j0:j1]
240+
i_end = None if proc_i == p_prime - 1 else i1
241+
j_end = None if proc_j == p_prime - 1 else j1
242+
block = array[i0:i_end, j0:j_end]
243+
222244
pr = (new_r - orig_r) if proc_i == p_prime - 1 else 0
223245
pc = (new_c - orig_c) if proc_j == p_prime - 1 else 0
224-
if pr or pc:
225-
block = np.pad(block, [(0, pr), (0, pc)], mode='constant')
226-
246+
# padding intentionally disabled here: callers now receive the unpadded block
247+
# if pr or pc: block = np.pad(block, [(0, pr), (0, pc)], mode='constant')
227248
return block, (new_r, new_c)
228249

229250
@staticmethod
def block_gather(x, new_shape, orig_shape, comm):
    """Gather the block-distributed 2D array ``x`` into one full matrix.

    Parameters
    ----------
    x : DistributedArray
        Distributed array whose ``local_array`` holds this process's block.
    new_shape : tuple of int
        Global ``(rows, cols)`` of the block-distributed matrix.
    orig_shape : tuple of int
        Original ``(rows, cols)``; the gathered result is trimmed to this.
    comm : MPI communicator
        Communicator over the ``p' x p'`` process grid.

    Returns
    -------
    Array of shape ``orig_shape`` with every process's block placed at its
    grid position.
    """
    ncp = get_module(x.engine)
    p_prime = math.isqrt(comm.Get_size())
    all_blks = comm.allgather(x.local_array)

    nr, nc = new_shape
    orr, orc = orig_shape

    # Block sizes must mirror the ceil-based split used everywhere else in
    # this operator (__init__ padding and _matvec/_rmatvec local_shapes):
    # every grid row/column owns ceil(n / p') rows/cols, except the LAST
    # one, which owns the (possibly smaller) tail. The previous scheme gave
    # the +1 remainder to the FIRST processes instead, which disagrees with
    # the distribution whenever nr % p_prime != 0 (e.g. nr=7, p'=3 yields
    # [3, 2, 2] instead of the distributed [3, 3, 1]).
    br = math.ceil(nr / p_prime)
    bc = math.ceil(nc / p_prime)

    C = ncp.zeros((nr, nc), dtype=all_blks[0].dtype)

    for rank in range(p_prime * p_prime):
        # Linear rank -> 2D grid position (row-major).
        proc_row, proc_col = divmod(rank, p_prime)

        # Last grid row/column holds the tail block.
        block_rows = br if proc_row != p_prime - 1 else nr - (p_prime - 1) * br
        block_cols = bc if proc_col != p_prime - 1 else nc - (p_prime - 1) * bc

        start_row = proc_row * br
        start_col = proc_col * bc

        block = all_blks[rank]
        # local_array may arrive flattened; restore its 2D block form.
        if block.ndim == 1:
            block = block.reshape(block_rows, block_cols)
        C[start_row:start_row + block_rows,
          start_col:start_col + block_cols] = block
    return C[:orr, :orc]
239290

240291
def _matvec(self, x: DistributedArray) -> DistributedArray:
    """Distributed matrix-vector product ``y = A @ x`` via SUMMA-style
    broadcasts over the ``p' x p'`` process grid.

    Parameters
    ----------
    x : DistributedArray
        SCATTER-partitioned operand holding this process's (local_k, local_m)
        block of the K x M input, flattened.

    Returns
    -------
    DistributedArray
        SCATTER-partitioned result holding this process's (local_n, local_m)
        block of the N x M product, flattened.

    Raises
    ------
    ValueError
        If ``x`` is not SCATTER-partitioned.
    """
    ncp = get_module(x.engine)
    if x.partition != Partition.SCATTER:
        raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")

    # Padded block sizes along the output (N) and column (M) dimensions.
    bn = self._N_padded // self._P_prime
    bm = self._M_padded // self._P_prime

    # Actual (unpadded) block sizes on this process: the last grid
    # row/column owns the possibly smaller tail of N / M.
    local_n = bn
    local_m = bm
    if self._row_id == self._P_prime - 1:
        local_n = self.N - (self._P_prime - 1) * bn
    if self._col_id == self._P_prime - 1:
        local_m = self.M - (self._P_prime - 1) * bm

    # Unpadded local sizes for every rank, needed to build the output array.
    local_shapes = []
    for rank in range(self.size):
        row_id, col_id = divmod(rank, self._P_prime)
        proc_n = bn if row_id != self._P_prime - 1 else self.N - (self._P_prime - 1) * bn
        proc_m = bm if col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm
        local_shapes.append(proc_n * proc_m)

    y = DistributedArray(global_shape=(self.N * self.M),
                         mask=x.mask,
                         local_shapes=local_shapes,
                         partition=Partition.SCATTER,
                         dtype=self.dtype,
                         base_comm=self.base_comm)

    # Padded/actual block size along the inner (K) dimension.
    bk = self._K_padded // self._P_prime
    local_k = bk
    if self._row_id == self._P_prime - 1:
        local_k = self.K - (self._P_prime - 1) * bk

    # Restore this process's flattened slice of B to its 2D block form.
    x_block = x.local_array.reshape((local_k, local_m))

    # Zero-pad the tail block up to the uniform padded size so that every
    # rank broadcasts/receives buffers of identical shape.
    pad_k = bk - local_k
    pad_m = bm - local_m
    if pad_k > 0 or pad_m > 0:
        # use the engine module (ncp), not np, so non-NumPy engines work
        x_block = ncp.pad(x_block, [(0, pad_k), (0, pad_m)], mode='constant')

    # Accumulator must carry the operator dtype: an in-place += of a
    # complex product into a default-float64 buffer would fail/truncate.
    Y_local = ncp.zeros((self.A.shape[0], bm), dtype=self.dtype)

    # SUMMA loop: at step k the k-th grid column broadcasts its A block
    # along the row communicator and the k-th grid row broadcasts its x
    # block along the column communicator; partial products accumulate.
    for k in range(self._P_prime):
        Atemp = self.A.copy() if self._col_id == k else ncp.empty_like(self.A)
        Xtemp = x_block.copy() if self._row_id == k else ncp.empty_like(x_block)
        self._row_comm.Bcast(Atemp, root=k)
        self._col_comm.Bcast(Xtemp, root=k)
        Y_local += ncp.dot(Atemp, Xtemp)

    # Strip padding before writing into the unpadded distributed output.
    Y_local_unpadded = Y_local[:local_n, :local_m]
    y[:] = Y_local_unpadded.flatten()
    return y
262359

263360
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
264361
ncp = get_module(x.engine)
265362
if x.partition != Partition.SCATTER:
266363
raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
267364

268-
local_shape = ((self.K * self.M ) // self.size)
365+
# Calculate local shapes for block distribution
366+
bk = self._K_padded // self._P_prime # block size in K dimension
367+
bm = self._M_padded // self._P_prime # block size in M dimension
368+
369+
# Calculate actual local shape for this process (considering original dimensions)
370+
local_k = bk
371+
local_m = bm
372+
373+
# Adjust for edge/corner processes that might have smaller blocks
374+
if self._row_id == self._P_prime - 1:
375+
local_k = self.K - (self._P_prime - 1) * bk
376+
if self._col_id == self._P_prime - 1:
377+
local_m = self.M - (self._P_prime - 1) * bm
378+
379+
local_shape = local_k * local_m
380+
381+
# Create local_shapes array for all processes
382+
local_shapes = []
383+
for rank in range(self.size):
384+
row_id, col_id = divmod(rank, self._P_prime)
385+
proc_k = bk if row_id != self._P_prime - 1 else self.K - (self._P_prime - 1) * bk
386+
proc_m = bm if col_id != self._P_prime - 1 else self.M - (self._P_prime - 1) * bm
387+
local_shapes.append(proc_k * proc_m)
388+
269389
y = DistributedArray(
270390
global_shape=(self.K * self.M),
271391
mask=x.mask,
272-
local_shapes=[local_shape] * self.size,
392+
local_shapes=local_shapes,
273393
partition=Partition.SCATTER,
274394
dtype=self.dtype,
275395
base_comm=self.base_comm
276396
)
277-
x_reshaped = x.local_array.reshape((self.A.shape[0], -1))
397+
398+
# Calculate expected padded dimensions for x
399+
bn = self._N_padded // self._P_prime # block size in N dimension
400+
401+
# The input x corresponds to blocks from the result (N x M)
402+
# This process should receive a block of size (local_n x local_m)
403+
local_n = bn
404+
if self._row_id == self._P_prime - 1:
405+
local_n = self.N - (self._P_prime - 1) * bn
406+
407+
# Reshape x.local_array to its 2D block form
408+
x_block = x.local_array.reshape((local_n, local_m))
409+
410+
# Pad the block to the full padded size if necessary
411+
pad_n = bn - local_n
412+
pad_m = bm - local_m
413+
414+
if pad_n > 0 or pad_m > 0:
415+
x_block = np.pad(x_block, [(0, pad_n), (0, pad_m)], mode='constant')
416+
278417
A_local = self.At if hasattr(self, "At") else self.A.T.conj()
279-
Y_local = np.zeros((self.A.shape[1], x_reshaped.shape[1]))
418+
Y_local = np.zeros((self.A.shape[1], bm))
280419

281420
for k in range(self._P_prime):
282421
requests = []
@@ -289,10 +428,12 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
289428
for moving_col in range(self._P_prime):
290429
destA = fixed_col * self._P_prime + moving_col
291430
tagA = (100 + k) * 1000 + destA
292-
requests.append(self.base_comm.Isend(A_local, dest=destA,tag=tagA))
293-
Xtemp = x_reshaped.copy() if self._row_id == k else np.empty_like(x_reshaped)
431+
requests.append(self.base_comm.Isend(A_local, dest=destA, tag=tagA))
432+
Xtemp = x_block.copy() if self._row_id == k else np.empty_like(x_block)
294433
requests.append(self._col_comm.Ibcast(Xtemp, root=k))
295434
MPI.Request.Waitall(requests)
296435
Y_local += ncp.dot(ATtemp, Xtemp)
297-
y[:] = Y_local.flatten()
436+
437+
Y_local_unpadded = Y_local[:local_k, :local_m]
438+
y[:] = Y_local_unpadded.flatten()
298439
return y

0 commit comments

Comments
 (0)