From 13025d74f18e6c8c97e0de27f959f6637ea3bd8c Mon Sep 17 00:00:00 2001
From: mrava87
Date: Wed, 2 Jul 2025 21:45:19 +0000
Subject: [PATCH] minor: moved active_grid_comm before _matvec

---
 pylops_mpi/basicoperators/MatrixMult.py | 54 ++++++++++++-------------
 1 file changed, 27 insertions(+), 27 deletions(-)

diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py
index 0dcee587..39eda45e 100644
--- a/pylops_mpi/basicoperators/MatrixMult.py
+++ b/pylops_mpi/basicoperators/MatrixMult.py
@@ -163,37 +163,12 @@ def __init__(
         shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
         super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
 
-    def _matvec(self, x: DistributedArray) -> DistributedArray:
-        ncp = get_module(x.engine)
-        if x.partition != Partition.SCATTER:
-            raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
-
-        y = DistributedArray(
-            global_shape=(self.N * self.dimsd[1]),
-            local_shapes=[(self.N * c) for c in self._rank_col_lens],
-            mask=x.mask,
-            partition=Partition.SCATTER,
-            dtype=self.dtype,
-            base_comm=self.base_comm
-        )
-
-        my_own_cols = self._rank_col_lens[self.rank]
-        x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
-        X_local = x_arr.astype(self.dtype)
-        Y_local = ncp.vstack(
-            self._row_comm.allgather(
-                ncp.matmul(self.A, X_local)
-            )
-        )
-        y[:] = Y_local.flatten()
-        return y
-
     @staticmethod
     def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
         r"""Configure active grid
 
         Configure a square process grid from a parent MPI communicator and
-        select the subset of "active" processes. Each process in ``base_comm``
+        select a subset of "active" processes. Each process in ``base_comm``
         is assigned to a logical 2D grid of size :math:`P' \times P'`,
         where :math:`P' = \bigl\lceil \sqrt{P} \bigr\rceil`. Only the first
         :math:`\text{active\_dim} \times \text{active\_dim}` processes
@@ -218,7 +193,7 @@ def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
             if inactive).
         row : :obj:`int`
             Grid row index of this process in the active grid (or original rank
-            if inactive). 
+            if inactive).
         col : :obj:`int`
             Grid column index of this process in the active grid (or original
             rank if inactive).
@@ -246,6 +221,31 @@ def active_grid_comm(base_comm: MPI.Comm, N: int, M: int):
 
         return new_comm, new_rank, new_row, new_col, True
 
+    def _matvec(self, x: DistributedArray) -> DistributedArray:
+        ncp = get_module(x.engine)
+        if x.partition != Partition.SCATTER:
+            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead...")
+
+        y = DistributedArray(
+            global_shape=(self.N * self.dimsd[1]),
+            local_shapes=[(self.N * c) for c in self._rank_col_lens],
+            mask=x.mask,
+            partition=Partition.SCATTER,
+            dtype=self.dtype,
+            base_comm=self.base_comm
+        )
+
+        my_own_cols = self._rank_col_lens[self.rank]
+        x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
+        X_local = x_arr.astype(self.dtype)
+        Y_local = ncp.vstack(
+            self._row_comm.allgather(
+                ncp.matmul(self.A, X_local)
+            )
+        )
+        y[:] = Y_local.flatten()
+        return y
+
     def _rmatvec(self, x: DistributedArray) -> DistributedArray:
         ncp = get_module(x.engine)
         if x.partition != Partition.SCATTER:
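
For readers skimming the relocated code, the grid rule in the active_grid_comm docstring (ranks placed row-major on a logical P' x P' grid with P' = ceil(sqrt(P)), of which only the top-left active_dim x active_dim corner stays active) can be sketched in plain Python. This is an illustration only, not pylops-mpi code: the function name sketch_grid_assignment is invented, and active_dim = min(P', N, M) is an assumption, since the docstring excerpt in this patch does not show how active_dim is computed.

import math

def sketch_grid_assignment(rank: int, size: int, N: int, M: int):
    # P' = ceil(sqrt(P)): side of the smallest square grid holding all P ranks
    p_prime = math.isqrt(size)
    if p_prime * p_prime < size:
        p_prime += 1
    # Place ranks row-major on the logical P' x P' grid
    row, col = divmod(rank, p_prime)
    # Assumption: only ranks inside the top-left active_dim x active_dim
    # corner of the grid remain active
    active_dim = min(p_prime, N, M)
    return row, col, (row < active_dim and col < active_dim)

# With P = 7 ranks, P' = 3; with N = M = 2 only the 2 x 2 corner is active
for r in range(7):
    print(r, sketch_grid_assignment(r, 7, N=2, M=2))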
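
The moved _matvec builds its output by multiplying the local block of A against the local reshaped block of x, allgathering the partial products over _row_comm, and vstacking them. A minimal mpi4py sketch of that allgather-and-stack pattern follows; the toy shapes are assumptions, and MPI.COMM_WORLD stands in for the operator's _row_comm sub-communicator, so this is a simplified stand-in rather than the operator itself.

import numpy as np
from mpi4py import MPI

# Run with, e.g.: mpiexec -n 2 python sketch_matvec.py
comm = MPI.COMM_WORLD        # stand-in for the operator's _row_comm
rank = comm.Get_rank()
size = comm.Get_size()

K = 4                                                  # shared inner dimension
A_local = np.full((2, K), rank + 1.0)                  # this rank's block of A
X_local = np.arange(K * 3, dtype=float).reshape(K, 3)  # local block of x, reshaped

# Same pattern as _matvec: multiply locally, allgather the partial row
# blocks across the communicator, then vstack them so every rank holds
# all rows of the result
Y_local = np.vstack(comm.allgather(A_local @ X_local))
print(f"rank {rank}: Y_local.shape = {Y_local.shape}")  # (2 * size, 3)

Every rank ends up with the same stacked result in this toy run; in the real operator each rank keeps only its own output columns, as encoded by the local_shapes passed to DistributedArray.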