Skip to content

Commit 8a56096

Browse files
committed
I do not know why I thought I needed to batch
1 parent 7ac593d commit 8a56096

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,27 @@
1212

1313

1414
class MPIMatrixMult(MPILinearOperator):
15+
r"""
16+
Distributed Matrix-Matrix multiplication
17+
Implements a distributed matrix-matrix multiplication
18+
19+
Parameters
20+
----------
21+
A : :obj:`numpy.ndarray`
22+
Matrix multiplication operator of size
23+
:math:`[N \times K]`
24+
saveAt : :obj:`bool`, optional
25+
Save ``A`` and ``A.H`` to speed up the computation of adjoint
26+
(``True``) or create ``A.H`` on-the-fly (``False``)
27+
Note that ``saveAt=True`` will double the amount of required memory
28+
base_comm : :obj:`mpi4py.MPI.Comm`, optional
29+
MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
30+
dtype : :obj:`str`, optional
31+
Type of elements in input array.
32+
33+
Notes
34+
-----
35+
"""
1536
def __init__(
1637
self,
1738
A: NDArray,
@@ -102,17 +123,8 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
102123

103124
x_arr = x.local_array.reshape((self.M, self._local_ncols)).astype(self.dtype)
104125
X_tile = x_arr[self._row_start:self._row_end, :]
105-
106126
A_local = self.At if hasattr(self, "At") else self.A.T.conj()
107-
m, b = A_local.shape
108-
pad = (-m) % self._P_prime
109-
A_pad = A_local if pad <= 0 else np.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=self.dtype.type(0.0))
110-
batch_sz = (m + pad) // self._P_prime
111-
A_batch = A_pad.reshape(self._P_prime, batch_sz, b)
112-
113-
Y_batch = ncp.matmul(A_batch, X_tile)
114-
Y_pad = Y_batch.reshape(batch_sz * self._P_prime, -1)
115-
y_local = Y_pad[:A_local.shape[0], :]
116-
y_layer = self._layer_comm.allreduce(y_local, op=MPI.SUM)
127+
Y_local = ncp.matmul(A_local, X_tile)
128+
y_layer = self._layer_comm.allreduce(Y_local, op=MPI.SUM)
117129
y[:] = y_layer.flatten()
118130
return y

0 commit comments

Comments
 (0)