import math

import numpy as np
from mpi4py import MPI
from pylops.utils.backend import get_module
from pylops.utils.typing import DTypeLike, NDArray

from pylops_mpi import (
    DistributedArray,
    MPILinearOperator,
    Partition
)


class SUMMAMatrixMult(MPILinearOperator):
    r"""MPI distributed matrix multiplication based on a SUMMA-like algorithm.

    Each rank owns a horizontal block of rows of ``A``. The ranks of
    ``base_comm`` are arranged on a ``P' x C`` grid and split into per-layer
    and per-group communicators used to exchange operand blocks and partial
    products.

    Parameters
    ----------
    A : :obj:`numpy.ndarray`
        Local block of rows of the matrix to be multiplied.
    N : :obj:`int`
        Global number of columns of the operand matrix.
    base_comm : :obj:`mpi4py.MPI.Comm`, optional
        MPI base communicator.
    dtype : :obj:`str`, optional
        Type of elements of the operator.

    """

    def __init__(
        self,
        A: NDArray,
        N: int,
        base_comm: MPI.Comm = MPI.COMM_WORLD,
        dtype: DTypeLike = "float64",
    ) -> None:
        rank = base_comm.Get_rank()
        size = base_comm.Get_size()

        # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size
        self._P_prime = int(math.ceil(math.sqrt(size)))
        self._C = int(math.ceil(size / self._P_prime))
        assert self._P_prime * self._C >= size
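        # e.g. size=4 -> 2 x 2 grid (all slots used); size=6 -> 3 x 2 grid;
        # size=5 -> 3 x 2 grid with one grid slot left unused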
| 29 | + |
| 30 | + # Compute this process's group and layer indices |
| 31 | + self._group_id = rank % self._P_prime |
| 32 | + self._layer_id = rank // self._P_prime |
| 33 | + |
| 34 | + # Split communicators by layer (rows) and by group (columns) |
| 35 | + self.base_comm = base_comm |
| 36 | + self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id) |
| 37 | + self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id) |
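        # e.g. with size=4 (2 x 2 grid): layers are ranks {0, 1} and {2, 3},
        # groups are ranks {0, 2} and {1, 3}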
| 38 | + |
| 39 | + self.dtype = np.dtype(dtype) |
| 40 | + self.A = np.array(A, dtype=self.dtype, copy=False) |
| 41 | + |
| 42 | + self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM) |
| 43 | + self.K = A.shape[1] |
| 44 | + self.N = N |
| 45 | + |
| 46 | + # Determine how many columns each group holds |
| 47 | + block_cols = int(math.ceil(self.N / self._P_prime)) |
| 48 | + local_col_start = self._group_id * block_cols |
| 49 | + local_col_end = min(self.N, local_col_start + block_cols) |
| 50 | + local_ncols = local_col_end - local_col_start |
| 51 | + |
| 52 | + # Sum up the total number of input columns across all processes |
| 53 | + total_ncols = base_comm.allreduce(local_ncols, op=MPI.SUM) |
| 54 | + self.dims = (self.K, total_ncols) |
| 55 | + |
| 56 | + # Recompute how many output columns each layer holds |
| 57 | + layer_col_start = self._layer_id * block_cols |
| 58 | + layer_col_end = min(self.N, layer_col_start + block_cols) |
| 59 | + layer_ncols = layer_col_end - layer_col_start |
| 60 | + total_layer_cols = self.base_comm.allreduce(layer_ncols, op=MPI.SUM) |
| 61 | + |
| 62 | + self.dimsd = (self.M, total_layer_cols) |
| 63 | + shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims))) |
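        # e.g. with size=4 (2 x 2 grid), M=8, K=6, N=4: each rank holds 2
        # operand columns, so dims=(6, 8), dimsd=(8, 8) and shape=(64, 48)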

        super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")

        # Columns of the operand owned by this process's group
        blk_cols = int(math.ceil(self.N / self._P_prime))
        col_start = self._group_id * blk_cols
        col_end = min(self.N, col_start + blk_cols)
        my_own_cols = max(0, col_end - col_start)
        x_local = x.local_array.reshape((self.dims[0], my_own_cols))
        x_local = x_local.astype(self.dtype, copy=False)

        # Broadcast the operand columns assigned to this layer within the
        # layer, then gather the row blocks of the partial products so that
        # every rank of the layer holds the full result for these columns
        B_block = self._layer_comm.bcast(
            x_local if self._group_id == self._layer_id else None,
            root=self._layer_id,
        )
        C_local = ncp.vstack(
            self._layer_comm.allgather(
                ncp.matmul(self.A, B_block)
            )
        )

        # Columns of the result owned by this process's layer
        layer_col_start = self._layer_id * blk_cols
        layer_col_end = min(self.N, layer_col_start + blk_cols)
        layer_ncols = max(0, layer_col_end - layer_col_start)
        layer_col_lens = self.base_comm.allgather(layer_ncols)
        # Ranks within the same layer hold the same block of the result
        mask = [i // self._P_prime for i in range(self.size)]

        y = DistributedArray(global_shape=self.M * self.dimsd[1],
                             local_shapes=[self.M * c for c in layer_col_lens],
                             mask=mask,
                             partition=Partition.SCATTER,
                             dtype=self.dtype)
        y[:] = C_local.flatten()
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition != Partition.SCATTER:
            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")

        # Determine local column block for this layer
        blk_cols = int(math.ceil(self.N / self._P_prime))
        layer_col_start = self._layer_id * blk_cols
        layer_col_end = min(self.N, layer_col_start + blk_cols)
        layer_ncols = layer_col_end - layer_col_start
        layer_col_lens = self.base_comm.allgather(layer_ncols)
        x_local = x.local_array.reshape((self.M, layer_ncols))

        # Determine local row block for this process group
        blk_rows = int(math.ceil(self.M / self._P_prime))
        row_start = self._group_id * blk_rows
        row_end = min(self.M, row_start + blk_rows)

        B_tile = x_local[row_start:row_end, :].astype(self.dtype, copy=False)
        A_local = self.A.T.conj()

        # Pad the rows of A^H so that they split evenly into P_prime batches
        m, b = A_local.shape
        pad = (-m) % self._P_prime
        r = (m + pad) // self._P_prime
        A_pad = np.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=0)
        A_batch = A_pad.reshape(self._P_prime, r, b)

        # Perform local batched matmul, unpad and sum the partial products
        # of all groups within the layer
        Y_batch = ncp.matmul(A_batch, B_tile)
        Y_pad = Y_batch.reshape(r * self._P_prime, -1)
        y_local = Y_pad[:m, :]
        y_layer = self._layer_comm.allreduce(y_local, op=MPI.SUM)

        # Ranks within the same layer hold the same block of the result
        mask = [i // self._P_prime for i in range(self.size)]
        y = DistributedArray(
            global_shape=self.K * self.dimsd[1],
            local_shapes=[self.K * c for c in layer_col_lens],
            mask=mask,
            partition=Partition.SCATTER,
            dtype=self.dtype,
        )
        y[:] = y_layer.flatten()
        return y
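

# ---------------------------------------------------------------------------
# A minimal usage sketch (illustrative only, not part of the operator): it
# assumes 4 MPI ranks, i.e. a 2 x 2 grid, and global sizes M=8, K=6, N=4.
# Ranks sharing a group id must hold the same rows of A and the same operand
# columns, which is emulated here by seeding the random generator with the
# group id. Run with e.g. `mpirun -n 4 python <this_file>.py`.
if __name__ == "__main__":
    comm = MPI.COMM_WORLD
    rank, size = comm.Get_rank(), comm.Get_size()

    M_glob, K_glob, N_glob = 8, 6, 4
    p_prime = int(math.ceil(math.sqrt(size)))
    group_id = rank % p_prime

    # Local horizontal slice of A (rows are blocked by group id)
    blk_rows = int(math.ceil(M_glob / p_prime))
    row_start = group_id * blk_rows
    row_end = min(M_glob, row_start + blk_rows)
    rng = np.random.default_rng(group_id)
    A_local = rng.standard_normal((row_end - row_start, K_glob))

    Aop = SUMMAMatrixMult(A_local, N_glob, base_comm=comm)

    # Distributed operand: each group owns a block of columns of X
    blk_cols = int(math.ceil(N_glob / p_prime))
    col_start = group_id * blk_cols
    col_end = min(N_glob, col_start + blk_cols)
    local_ncols = col_end - col_start
    x = DistributedArray(global_shape=int(np.prod(Aop.dims)),
                         local_shapes=comm.allgather(K_glob * local_ncols),
                         partition=Partition.SCATTER,
                         dtype=Aop.dtype)
    x[:] = rng.standard_normal(K_glob * local_ncols)

    y = Aop.matvec(x)      # distributed A @ X (one column block per layer)
    xadj = Aop.rmatvec(y)  # distributed A^H @ (A @ X)
    if rank == 0:
        print(f"y local shape: {y.local_array.shape}, "
              f"xadj local shape: {xadj.local_array.shape}")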