Commit 7fcc2cf

Added impl, test and example
1 parent ec4b8c1 commit 7fcc2cf

File tree

3 files changed: +437 -0 lines changed

examples/matrixmul.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
import sys
import math
import numpy as np
from mpi4py import MPI

from pylops_mpi import DistributedArray, Partition
from pylops_mpi.basicoperators.MatrixMultiply import SUMMAMatrixMult

np.random.seed(42)

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
nProcs = comm.Get_size()

P_prime = int(math.ceil(math.sqrt(nProcs)))
C = int(math.ceil(nProcs / P_prime))
assert P_prime * C >= nProcs

# matrix dims
M = 32  # any M
K = 32  # any K
N = 35  # any N

blk_rows = int(math.ceil(M / P_prime))
blk_cols = int(math.ceil(N / P_prime))

my_group = rank % P_prime
my_layer = rank // P_prime

# sub-communicators
layer_comm = comm.Split(color=my_layer, key=my_group)  # all procs in same layer
group_comm = comm.Split(color=my_group, key=my_layer)  # all procs in same group

# Each rank will end up with:
#   A_p: shape (my_own_rows, K)
#   B_p: shape (K, my_own_cols)
# where
row_start = my_group * blk_rows
row_end = min(M, row_start + blk_rows)
my_own_rows = row_end - row_start

col_start = my_group * blk_cols  # note: same my_group index on cols
col_end = min(N, col_start + blk_cols)
my_own_cols = col_end - col_start

# ======================= BROADCASTING THE SLICES =======================
if rank == 0:
    A = np.arange(M * K, dtype=np.float32).reshape(M, K)
    B = np.arange(K * N, dtype=np.float32).reshape(K, N)
    for dest in range(nProcs):
        pg = dest % P_prime
        rs = pg * blk_rows; re = min(M, rs + blk_rows)
        cs = pg * blk_cols; ce = min(N, cs + blk_cols)
        a_block, b_block = A[rs:re, :].copy(), B[:, cs:ce].copy()
        if dest == 0:
            A_p, B_p = a_block, b_block
        else:
            comm.Send(a_block, dest=dest, tag=100 + dest)
            comm.Send(b_block, dest=dest, tag=200 + dest)
else:
    A_p = np.empty((my_own_rows, K), dtype=np.float32)
    B_p = np.empty((K, my_own_cols), dtype=np.float32)
    comm.Recv(A_p, source=0, tag=100 + rank)
    comm.Recv(B_p, source=0, tag=200 + rank)

comm.Barrier()

Aop = SUMMAMatrixMult(A_p, N)
col_lens = comm.allgather(my_own_cols)
total_cols = np.add.reduce(col_lens, 0)
x = DistributedArray(global_shape=K * total_cols,
                     local_shapes=[K * col_len for col_len in col_lens],
                     partition=Partition.SCATTER,
                     mask=[i % P_prime for i in range(comm.Get_size())],
                     dtype=np.float32)
x[:] = B_p.flatten()
y = Aop @ x

# ======================= VERIFICATION =======================
A = np.arange(M * K).reshape(M, K).astype(np.float32)
B = np.arange(K * N).reshape(K, N).astype(np.float32)
C_true = A @ B
Z_true = (A.T.dot(C_true.conj())).conj()

col_start = my_layer * blk_cols  # note: output columns are indexed by layer
col_end = min(N, col_start + blk_cols)
my_own_cols = col_end - col_start
expected_y = C_true[:, col_start:col_end].flatten()

if not np.allclose(y.local_array, expected_y, atol=1e-6, rtol=1e-14):
    print(f"RANK {rank}: FORWARD VERIFICATION FAILED")
    print(f"{rank} local: {y.local_array}, expected: {C_true[:, col_start:col_end]}")
else:
    print(f"RANK {rank}: FORWARD VERIFICATION PASSED")

z = Aop.H @ y
expected_z = Z_true[:, col_start:col_end].flatten()
if not np.allclose(z.local_array, expected_z, atol=1e-6, rtol=1e-14):
    print(f"RANK {rank}: ADJOINT VERIFICATION FAILED")
    print(f"{rank} local: {z.local_array}, expected: {Z_true[:, col_start:col_end]}")
else:
    print(f"RANK {rank}: ADJOINT VERIFICATION PASSED")
pylops_mpi/basicoperators/MatrixMultiply.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
import numpy as np
import math
from mpi4py import MPI
from pylops.utils.backend import get_module
from pylops.utils.typing import DTypeLike, NDArray

from pylops_mpi import (
    DistributedArray,
    MPILinearOperator,
    Partition
)


class SUMMAMatrixMult(MPILinearOperator):
    def __init__(
        self,
        A: NDArray,
        N: int,
        base_comm: MPI.Comm = MPI.COMM_WORLD,
        dtype: DTypeLike = "float64",
    ) -> None:
        rank = base_comm.Get_rank()
        size = base_comm.Get_size()

        # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size
        self._P_prime = int(math.ceil(math.sqrt(size)))
        self._C = int(math.ceil(size / self._P_prime))
        assert self._P_prime * self._C >= size

        # Compute this process's group and layer indices
        self._group_id = rank % self._P_prime
        self._layer_id = rank // self._P_prime

        # Split communicators by layer (rows) and by group (columns)
        self.base_comm = base_comm
        self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
        self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)

        self.dtype = np.dtype(dtype)
        self.A = np.array(A, dtype=self.dtype, copy=False)

        self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
        self.K = A.shape[1]
        self.N = N

        # Determine how many columns each group holds
        block_cols = int(math.ceil(self.N / self._P_prime))
        local_col_start = self._group_id * block_cols
        local_col_end = min(self.N, local_col_start + block_cols)
        local_ncols = local_col_end - local_col_start

        # Sum up the total number of input columns across all processes
        total_ncols = base_comm.allreduce(local_ncols, op=MPI.SUM)
        self.dims = (self.K, total_ncols)

        # Recompute how many output columns each layer holds
        layer_col_start = self._layer_id * block_cols
        layer_col_end = min(self.N, layer_col_start + block_cols)
        layer_ncols = layer_col_end - layer_col_start
        total_layer_cols = self.base_comm.allreduce(layer_ncols, op=MPI.SUM)

        self.dimsd = (self.M, total_layer_cols)
        shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))

        super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
        blk_cols = int(math.ceil(self.N / self._P_prime))
        col_start = self._group_id * blk_cols
        col_end = min(self.N, col_start + blk_cols)
        my_own_cols = max(0, col_end - col_start)
        x = x.local_array.reshape((self.dims[0], my_own_cols))
        x = x.astype(self.dtype, copy=False)
        B_block = self._layer_comm.bcast(x if self._group_id == self._layer_id else None, root=self._layer_id)
        C_local = ncp.vstack(
            self._layer_comm.allgather(
                ncp.matmul(self.A, B_block)
            )
        )

        layer_col_start = self._layer_id * blk_cols
        layer_col_end = min(self.N, layer_col_start + blk_cols)
        layer_ncols = max(0, layer_col_end - layer_col_start)
        layer_col_lens = self.base_comm.allgather(layer_ncols)
        mask = [i // self._P_prime for i in range(self.size)]

        y = DistributedArray(global_shape=(self.M * self.dimsd[1]),
                             local_shapes=[(self.M * c) for c in layer_col_lens],
                             mask=mask,
                             # axis=1,
                             partition=Partition.SCATTER,
                             dtype=self.dtype)
        y[:] = C_local.flatten()
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")

        # Determine local column block for this layer
        blk_cols = int(math.ceil(self.N / self._P_prime))
        layer_col_start = self._layer_id * blk_cols
        layer_col_end = min(self.N, layer_col_start + blk_cols)
        layer_ncols = layer_col_end - layer_col_start
        layer_col_lens = self.base_comm.allgather(layer_ncols)
        x = x.local_array.reshape((self.M, layer_ncols))

        # Determine local row block for this process group
        blk_rows = int(math.ceil(self.M / self._P_prime))
        row_start = self._group_id * blk_rows
        row_end = min(self.M, row_start + blk_rows)

        B_tile = x[row_start:row_end, :].astype(self.dtype, copy=False)
        A_local = self.A.T.conj()

        m, b = A_local.shape
        pad = (-m) % self._P_prime
        r = (m + pad) // self._P_prime
        A_pad = np.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=0)
        A_batch = A_pad.reshape(self._P_prime, r, b)

        # Perform local matmul and unpad
        Y_batch = ncp.matmul(A_batch, B_tile)
        Y_pad = Y_batch.reshape(r * self._P_prime, -1)
        y_local = Y_pad[:m, :]
        y_layer = self._layer_comm.allreduce(y_local, op=MPI.SUM)

        mask = [i // self._P_prime for i in range(self.size)]
        y = DistributedArray(
            global_shape=(self.K * self.dimsd[1]),
            local_shapes=[self.K * c for c in layer_col_lens],
            mask=mask,
            # axis=1,
            partition=Partition.SCATTER,
            dtype=self.dtype,
        )
        y[:] = y_layer.flatten()
        return y
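As a plain-NumPy cross-check of what the operator is expected to compute (the same reference products the example's verification builds: C_true = A @ B for the forward pass and Z_true = A^H C_true for the adjoint), here is a small serial sketch with hypothetical toy sizes:

import numpy as np

M, K, N = 8, 8, 9  # toy sizes, chosen only for illustration
A = np.arange(M * K, dtype=np.float32).reshape(M, K)
B = np.arange(K * N, dtype=np.float32).reshape(K, N)

C = A @ B            # forward reference: what Aop @ x should assemble globally
Z = A.conj().T @ C   # adjoint reference: what Aop.H @ y should assemble globally

# Same quantity the example computes as Z_true = (A.T.dot(C.conj())).conj()
assert np.allclose(Z, (A.T.dot(C.conj())).conj())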
