|
12 | 12 |
|
13 | 13 |
|
14 | 14 | class MPIMatrixMult(MPILinearOperator): |
| 15 | + r""" |
| 16 | + Distributed Matrix-Matrix multiplication. |
| 17 | + Implements a matrix-matrix multiplication distributed across MPI ranks. |
| 18 | +
|
| 19 | + Parameters |
| 20 | + ---------- |
| 21 | + A : :obj:`numpy.ndarray` |
| 22 | + Matrix multiplication operator of size |
| 23 | + :math:`[N \times K]` |
| 24 | + saveAt : :obj:`bool`, optional |
| 25 | + Save ``A`` and ``A.H`` to speed up the computation of adjoint |
| 26 | + (``True``) or create ``A.H`` on-the-fly (``False``) |
| 27 | + Note that ``saveAt=True`` will double the amount of required memory. |
| 28 | + base_comm : :obj:`mpi4py.MPI.Comm`, optional |
| 29 | + MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``. |
| 30 | + dtype : :obj:`str`, optional |
| 31 | + Type of elements in input array. |
| 32 | +
|
| 33 | + Notes |
| 34 | + ----- |
| 35 | + """ |
15 | 36 | def __init__( |
16 | 37 | self, |
17 | 38 | A: NDArray, |
@@ -102,17 +123,8 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: |
102 | 123 |
|
103 | 124 | x_arr = x.local_array.reshape((self.M, self._local_ncols)).astype(self.dtype) |
104 | 125 | X_tile = x_arr[self._row_start:self._row_end, :] |
105 | | - |
106 | 126 | A_local = self.At if hasattr(self, "At") else self.A.T.conj() |
107 | | - m, b = A_local.shape |
108 | | - pad = (-m) % self._P_prime |
109 | | - A_pad = A_local if pad <= 0 else np.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=self.dtype.type(0.0)) |
110 | | - batch_sz = (m + pad) // self._P_prime |
111 | | - A_batch = A_pad.reshape(self._P_prime, batch_sz, b) |
112 | | - |
113 | | - Y_batch = ncp.matmul(A_batch, X_tile) |
114 | | - Y_pad = Y_batch.reshape(batch_sz * self._P_prime, -1) |
115 | | - y_local = Y_pad[:A_local.shape[0], :] |
116 | | - y_layer = self._layer_comm.allreduce(y_local, op=MPI.SUM) |
| 127 | + Y_batch = ncp.matmul(A_local, X_tile) |
| 128 | + y_layer = self._layer_comm.allreduce(Y_batch, op=MPI.SUM) |
117 | 129 | y[:] = y_layer.flatten() |
118 | 130 | return y |
0 commit comments