Skip to content

Commit d8d9463

Browse files
committed
Initial impl of SUMMA matmul
1 parent 5fcbad3 commit d8d9463

File tree

1 file changed

+29
-42
lines changed

1 file changed

+29
-42
lines changed

examples/matrixmul.py

Lines changed: 29 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -1,61 +1,48 @@
1-
import numpy as np
21
from mpi4py import MPI
32
import math
43
import pylops_mpi
4+
from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
5+
import numpy as np
56

67
comm = MPI.COMM_WORLD
78
rank = comm.Get_rank()
89
size = comm.Get_size()
910

10-
M = 8 #512
11-
N = 8 #512
12-
K = 8 #512
11+
N = 8
12+
M = 8
13+
K = 8
1314

14-
A_shape = (M,K)
15-
B_shape = (K,N)
16-
C_shape = (M,N)
15+
A_shape = (N, K)
16+
B_shape = (K, M)
17+
C_shape = (N, M)
1718

1819
p_prime = math.isqrt(size)
19-
assert p_prime*p_prime == size, "Number of processes must be a perfect square"
20+
assert p_prime * p_prime == size, "Number of processes must be a perfect square"
2021

21-
# Create A with 2D block-cyclic structure
22-
A_data = np.arange(int(A_shape[0]*A_shape[1])).reshape(A_shape)
23-
A = A_data.reshape(p_prime, M//p_prime, p_prime, K//p_prime).transpose(1, 0, 2, 3).reshape(M//p_prime, -1)
22+
A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape)
23+
B_data = np.arange(int(B_shape[0] * B_shape[1])).reshape(B_shape)
2424

25-
# Create B with 2D block-cyclic structure
26-
B_data = np.arange(int(B_shape[0]*B_shape[1])).reshape(B_shape)
27-
B = B_data.reshape(p_prime, K//p_prime, p_prime, N//p_prime).transpose(1, 0, 2, 3).reshape(K//p_prime, -1)
25+
N_starts, N_ends = MPIMatrixMult.block_distribute(N, p_prime)
26+
M_starts, M_ends = MPIMatrixMult.block_distribute(M, p_prime)
27+
K_starts, K_ends = MPIMatrixMult.block_distribute(K, p_prime)
2828

29-
A_dist = pylops_mpi.DistributedArray.to_dist(A,
30-
partition=pylops_mpi.Partition.SCATTER,
31-
axis=1)
32-
B_dist = pylops_mpi.DistributedArray.to_dist(B,
33-
partition=pylops_mpi.Partition.SCATTER,
34-
axis=1)
29+
i, j = divmod(rank, p_prime)
30+
A_local = A_data[N_starts[i]:N_ends[i], K_starts[j]:K_ends[j]]
31+
B_local = B_data[K_starts[i]:K_ends[i], M_starts[j]:M_ends[j]]
3532

36-
C_dist = pylops_mpi.DistributedArray(global_shape=(M // p_prime, N * p_prime),
37-
partition=pylops_mpi.Partition.SCATTER,
38-
axis=1)
39-
if rank == 0: print(A_dist.local_array)
33+
B_dist = pylops_mpi.DistributedArray(global_shape=(K*M),
34+
local_shapes=comm.allgather(B_local.shape[0] * B_local.shape[1]),
35+
base_comm=comm,
36+
partition=pylops_mpi.Partition.SCATTER)
37+
B_dist.local_array[:] = B_local.flatten()
4038

41-
i, j = divmod(rank, p_prime)
42-
row_comm = comm.Split(color=i, key=j)
43-
col_comm = comm.Split(color=j, key=i)
44-
45-
c_local = np.zeros((M//p_prime, N//p_prime))
46-
for k in range(p_prime):
47-
Atemp=A_dist.local_array.copy() if j==k else np.empty_like(A_dist.local_array)
48-
Btemp=B_dist.local_array.copy() if i==k else np.empty_like(B_dist.local_array)
49-
rootA=i*p_prime+k; rootB=k*p_prime+j
50-
row_comm.Bcast([Atemp,MPI.FLOAT],root=k)
51-
col_comm.Bcast([Btemp,MPI.FLOAT],root=k)
52-
# print(f"[Rank {rank}] iter{k} after : received A from {rootA}, B from {rootB}, A0={Atemp.flat[0]},B0={Btemp.flat[0]}")
53-
c_local += Atemp @ Btemp
54-
55-
C_dist.local_array[:] = c_local
56-
C_temp = C_dist.asarray().reshape((M,N))
57-
C = C_temp.reshape(M//p_prime, p_prime, p_prime, N//p_prime).transpose(1, 0, 2, 3).reshape(M, N)
39+
print(rank, A_local.shape)
40+
Aop = MPIMatrixMult(A_local, M, base_comm=comm)
41+
C_dist = Aop @ B_dist
42+
C_temp = C_dist.asarray().reshape((N, M))
43+
C = C_temp.reshape(N // p_prime, p_prime, p_prime, M // p_prime).transpose(1, 0, 2, 3).reshape(N, M)
5844

5945
if rank == 0 :
60-
print("expected:\n",A_data @ B_data)
46+
# print("expected:\n",np.allclose(A_data @ B_data, C))
47+
print("expected:\n", A_data @ B_data)
6148
print("calculated:\n",C)

0 commit comments

Comments (0)