
Commit 8142d44

impl adjoint
1 parent a9e679e commit 8142d44

File tree: 2 files changed (+113, -33 lines)


examples/matrixmul.py

Lines changed: 13 additions & 13 deletions
@@ -1,16 +1,9 @@
-from mpi4py import MPI
 import math
-import pylops_mpi
-from pylops_mpi.DistributedArray import local_split
-from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
 import numpy as np
+from mpi4py import MPI
 
-
-import numpy as np
-import math
-
-
-
+import pylops_mpi
+from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
 
 comm = MPI.COMM_WORLD
 rank = comm.Get_rank()
@@ -44,9 +37,16 @@
 print(rank, A_local.shape)
 Aop = MPIMatrixMult(A_local, M_new, base_comm=comm)
 C_dist = Aop @ B_dist
-C = MPIMatrixMult.block_gather(C_dist, (N_new,M_new), (N,M), comm)
+Z_dist = Aop.H @ C_dist
 
+C = MPIMatrixMult.block_gather(C_dist, (N_new,M_new), (N,M), comm)
+Z = MPIMatrixMult.block_gather(Z_dist, (K_new,M_new), (K,M), comm)
 if rank == 0 :
+    print("expected:\n", np.allclose((A_data.T.dot((A_data @ B_data).conj())).conj(), Z.astype(np.int32)))
+    # print("expected:\n", (A_data.T.dot((A_data @ B_data).conj())).conj())
+    # print("calculated:\n",Z.astype(np.int32))
+    # print("calculated:\n", (A_data.T.dot((A_data @ B_data).conj())).conj() == Z.astype(np.int32))
+
     # print("expected:\n",np.allclose(A_data @ B_data, C))
-    print("expected:\n", A_data @ B_data)
-    print("calculated:\n",C)
+    # print("expected:\n", A_data @ B_data)
+    # print("calculated:\n",C)

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 100 additions & 20 deletions
@@ -152,6 +152,64 @@ def __init__(
         shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
         super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
 
+    @staticmethod
+    def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
+        r"""Configure active grid
+
+        Configure a square process grid from a parent MPI communicator and
+        select a subset of "active" processes. Each process in ``base_comm``
+        is assigned to a logical 2D grid of size :math:`P' \times P'`,
+        where :math:`P' = \bigl \lceil \sqrt{P} \bigr \rceil`. Only the first
+        :math:`active_dim x active_dim` processes
+        (by row-major order) are considered "active". Inactive ranks return
+        immediately with no new communicator.
+
+        Parameters:
+        -----------
+        base_comm : :obj:`mpi4py.MPI.Comm`
+            MPI Parent Communicator. (e.g., ``mpi4py.MPI.COMM_WORLD``).
+        N : :obj:`int`
+            Number of rows of the global data domain.
+        M : :obj:`int`
+            Number of columns of the global data domain.
+
+        Returns:
+        --------
+        comm : :obj:`mpi4py.MPI.Comm`
+            Sub-communicator including only active ranks.
+        rank : :obj:`int`
+            Rank within the new sub-communicator (or original rank
+            if inactive).
+        row : :obj:`int`
+            Grid row index of this process in the active grid (or original rank
+            if inactive).
+        col : :obj:`int`
+            Grid column index of this process in the active grid
+            (or original rank if inactive).
+        is_active : :obj:`bool`
+            Flag indicating whether this rank is in the active sub-grid.
+
+        """
+        rank = base_comm.Get_rank()
+        size = base_comm.Get_size()
+        p_prime = math.isqrt(size)
+        row, col = divmod(rank, p_prime)
+        active_dim = min(N, M, p_prime)
+        is_active = (row < active_dim and col < active_dim)
+
+        if not is_active:
+            return None, rank, row, col, False
+
+        active_ranks = [r for r in range(size)
+                        if (r // p_prime) < active_dim and (r % p_prime) < active_dim]
+        new_group = base_comm.Get_group().Incl(active_ranks)
+        new_comm = base_comm.Create_group(new_group)
+        p_prime_new = math.isqrt(len(active_ranks))
+        new_rank = new_comm.Get_rank()
+        new_row, new_col = divmod(new_rank, p_prime_new)
+
+        return new_comm, new_rank, new_row, new_col, True
+
     @staticmethod
     def block_distribute(array, proc_i, proc_j, comm):
         p_prime = math.isqrt(comm.Get_size())
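
The `active_grid_comm` helper added above returns `(comm, rank, row, col, is_active)`. A usage sketch based only on the signature and docstring in this hunk (hypothetical N and M; launch with `mpirun -n <P>`):

```python
from mpi4py import MPI
from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult

N, M = 8, 8  # hypothetical global dimensions
comm, rank, row, col, is_active = MPIMatrixMult.active_grid_comm(MPI.COMM_WORLD, N, M)

if is_active:
    # this rank sits at (row, col) of the active grid and received a sub-communicator
    print(f"rank {rank}: active at ({row}, {col}), sub-communicator size {comm.Get_size()}")
else:
    # inactive ranks get comm=None back and should skip the distributed work
    print(f"rank {rank}: inactive")
```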
@@ -188,11 +246,12 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
-        y = DistributedArray(global_shape=(self.N // self._P_prime, self.M * self._P_prime),
+        local_shape = (self.N // self._P_prime) * ( self.M * self._P_prime // self.size)
+        y = DistributedArray(global_shape=((self.N // self._P_prime) * self.M * self._P_prime),
                             mask=x.mask,
+                            local_shapes=[ local_shape for _ in range(self.size)],
                             partition=Partition.SCATTER,
-                            dtype=self.dtype,
-                            axis=1)
+                            dtype=self.dtype)
 
         x = x.local_array.reshape((self.A.shape[1], -1))
         c_local = np.zeros((self.A.shape[0], x.shape[1]))
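
The flattened `_matvec` output above declares a per-rank `local_shape` of `(N // P') * (M * P' // size)`. On a square grid with `size = P'**2` and evenly divisible dimensions, these local counts add up to the declared `global_shape`; a quick arithmetic sketch with hypothetical values:

```python
# Shape bookkeeping for the flattened output y, assuming a square P' x P' grid
# (size = P'**2) and dimensions that divide evenly.
P_prime = 2
size = P_prime ** 2
N, M = 8, 6

local_shape = (N // P_prime) * (M * P_prime // size)   # elements held by each rank
global_shape = (N // P_prime) * M * P_prime            # total elements in y

assert local_shape * size == global_shape
print(local_shape, global_shape)   # 12 48
```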
@@ -202,26 +261,47 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
             self._row_comm.Bcast(Atemp, root=k)
             self._col_comm.Bcast(Xtemp, root=k)
             c_local += ncp.dot(Atemp, Xtemp)
-        y[:] = c_local
+        y[:] = c_local.flatten()
         return y
 
+
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
-        return None
-        # y = DistributedArray(
-        #     global_shape=(self.K * self.dimsd[1]),
-        #     local_shapes=[self.K * c for c in self._rank_col_lens],
-        #     mask=x.mask,
-        #     partition=Partition.SCATTER,
-        #     dtype=self.dtype,
-        # )
-        #
-        # x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype)
-        # X_tile = x_arr[self._row_start:self._row_end, :]
-        # A_local = self.At if hasattr(self, "At") else self.A.T.conj()
-        # Y_local = ncp.matmul(A_local, X_tile)
-        # y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM)
-        # y[:] = y_layer.flatten()
-        # return y
+
+        local_shape = (self.K // self._P_prime) * (self.M * self._P_prime // self.size)
+        y = DistributedArray(
+            global_shape=((self.K // self._P_prime) * self.M * self._P_prime),
+            mask=x.mask,
+            local_shapes=[local_shape for _ in range(self.size)],
+            partition=Partition.SCATTER,
+            dtype=self.dtype,
+        )
+        x_reshaped = x.local_array.reshape((self.A.shape[0], -1))
+        A_local = self.At if hasattr(self, "At") else self.A.T.conj()
+        c_local = np.zeros((self.A.shape[1], x_reshaped.shape[1]))
+        P = self._P_prime
+
+        for k in range(P):
+            temps = {}
+            requests = []
+            for buf, owner, base, name in (
+                (A_local, self._row_id, 100, 'A'),
+                (x_reshaped, self._col_id, 200, 'B'),
+            ):
+                tmp = np.empty_like(buf)
+                temps[name] = tmp
+                src, tag = k * P + owner, (base + k) * 1000 + self.rank
+                requests.append(self.base_comm.Irecv(tmp, source=src, tag=tag))
+
+                if self.rank // P == k:
+                    fixed = self.rank % P
+                    for moving in range(P):
+                        dest = (fixed * P + moving) if name == 'A' else moving * P + fixed
+                        tag = (base + k) * 1000 + dest
+                        requests.append(self.base_comm.Isend(buf, dest=dest, tag=tag))
+            MPI.Request.Waitall(requests)
+            c_local += ncp.dot(temps['A'], temps['B'])
+        y[:] = c_local.flatten()
+        return y
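
The accumulation in the new `_rmatvec` follows the block identity Z_ij = sum_k (A_ki)^H C_kj for Z = A^H C, with the tagged Isend/Irecv pairs delivering the A^H and C tiles each rank needs at step k. A serial NumPy sketch of that identity (no MPI; block sizes are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(1)
P = 2                  # simulated P' x P' process grid
n, k, m = 4, 6, 2      # per-block sizes (hypothetical)

# Block-partitioned A (P*n x P*k) and C (P*n x P*m)
A = rng.standard_normal((P * n, P * k)) + 1j * rng.standard_normal((P * n, P * k))
C = rng.standard_normal((P * n, P * m)) + 1j * rng.standard_normal((P * n, P * m))
A_blk = lambda i, j: A[i * n:(i + 1) * n, j * k:(j + 1) * k]
C_blk = lambda i, j: C[i * n:(i + 1) * n, j * m:(j + 1) * m]

# Each "rank" (i, j) assembles its block of Z = A^H @ C as a sum over the grid
# index kk; this is the data the point-to-point exchange in _rmatvec moves around.
Z = np.zeros((P * k, P * m), dtype=complex)
for i in range(P):
    for j in range(P):
        Z[i * k:(i + 1) * k, j * m:(j + 1) * m] = sum(
            A_blk(kk, i).conj().T @ C_blk(kk, j) for kk in range(P)
        )

print(np.allclose(Z, A.conj().T @ C))   # True
```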
