diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst
index 3c39cc0a..66e1a373 100644
--- a/docs/source/api/index.rst
+++ b/docs/source/api/index.rst
@@ -42,6 +42,7 @@ Basic Operators
 .. autosummary::
    :toctree: generated/
 
+    MPIMatrixMult
     MPIBlockDiag
     MPIStackedBlockDiag
     MPIVStack
diff --git a/examples/plot_matrixmult.py b/examples/plot_matrixmult.py
new file mode 100644
index 00000000..47173ba0
--- /dev/null
+++ b/examples/plot_matrixmult.py
@@ -0,0 +1,223 @@
+r"""
+Distributed Matrix Multiplication
+=================================
+This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
+operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}`
+blocked over rows (i.e., blocks of rows are stored over different ranks) and a
+matrix :math:`\mathbf{X}` blocked over columns (i.e., blocks of columns are
+stored over different ranks), with an equal number of row and column blocks.
+Similarly, the adjoint operation can be performed with a matrix :math:`\mathbf{Y}`
+blocked in the same fashion as matrix :math:`\mathbf{X}`.
+
+Note that whilst the different blocks of the matrix :math:`\mathbf{A}` are directly
+stored in the operator on different ranks, the matrix :math:`\mathbf{X}` is
+effectively represented by a 1-D :py:class:`pylops_mpi.DistributedArray` where
+the different blocks are flattened and stored on different ranks. Note that to
+optimize communications, the ranks are organized in a 2D grid and some of the
+row blocks of :math:`\mathbf{A}` and column blocks of :math:`\mathbf{X}` are
+replicated across different ranks - see below for details.
+
+"""
+
+from matplotlib import pyplot as plt
+import math
+import numpy as np
+from mpi4py import MPI
+
+from pylops_mpi import DistributedArray, Partition
+from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
+
+plt.close("all")
+
+###############################################################################
+# We set the seed such that all processes create the input matrices filled
+# with the same random numbers. In practical applications, such matrices will
+# be filled with data that is appropriate for the use-case.
+np.random.seed(42)
+
+###############################################################################
+# We are now ready to create the input matrices: :math:`\mathbf{A}` of size
+# :math:`N \times K` and :math:`\mathbf{X}` of size :math:`K \times M`.
+N, K, M = 4, 4, 4
+A = np.random.rand(N * K).astype(dtype=np.float32).reshape(N, K)
+X = np.random.rand(K * M).astype(dtype=np.float32).reshape(K, M)
+
+###############################################################################
+# The processes are now arranged in a :math:`P' \times P'` grid,
+# where :math:`P` is the total number of processes.
+#
+# We define
+#
+# .. math::
+#    P' = \bigl \lceil \sqrt{P} \bigr \rceil
+#
+# and the replication factor
+#
+# .. math::
+#    R = \bigl\lceil \tfrac{P}{P'} \bigr\rceil.
+#
+# Each process is therefore assigned a pair of coordinates
+# :math:`(r,c)` within this grid:
+#
+# .. math::
+#    r = \left\lfloor \frac{\mathrm{rank}}{P'} \right\rfloor,
+#    \quad
+#    c = \mathrm{rank} \bmod P'.
+#
+# For example, when :math:`P = 4` we have :math:`P' = 2`, giving a 2×2 layout:
+#
+# .. raw:: html
+#
+#
+# ┌────────────┬────────────┐ +# │ Rank 0 │ Rank 1 │ +# │ (r=0, c=0) │ (r=0, c=1) │ +# ├────────────┼────────────┤ +# │ Rank 2 │ Rank 3 │ +# │ (r=1, c=0) │ (r=1, c=1) │ +# └────────────┴────────────┘ +#
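+#
+# The mapping from a linear rank to grid coordinates can be sketched with
+# plain Python (a minimal illustration of the formulas above; ``p_prime``
+# stands for :math:`P'` and is an assumed value here, not taken from MPI):
+#
+# .. code-block:: python
+#
+#     p_prime = 2
+#     for rank in range(4):
+#         r, c = divmod(rank, p_prime)  # e.g., rank 3 -> (r=1, c=1)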
+#
+# This is obtained by invoking the
+# :func:`pylops_mpi.MPIMatrixMult.active_grid_comm` method, which is also
+# responsible for identifying any rank that should be deactivated (i.e., if
+# the number of rows of the operator or columns of the input/output matrices
+# is smaller than the number of row or column ranks in the grid).

+base_comm = MPI.COMM_WORLD
+comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M)
+print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
+if not is_active: exit(0)
+
+# Create sub-communicators
+p_prime = math.isqrt(comm.Get_size())
+row_comm = comm.Split(color=row_id, key=col_id)  # all procs in same row
+col_comm = comm.Split(color=col_id, key=row_id)  # all procs in same col
+
+###############################################################################
+# At this point we divide the rows and columns of :math:`\mathbf{A}` and
+# :math:`\mathbf{X}`, respectively, such that each rank ends up with:
+#
+# - :math:`A_{p} \in \mathbb{R}^{\text{my_own_rows}\times K}`
+# - :math:`X_{p} \in \mathbb{R}^{K\times \text{my_own_cols}}`
+#
+# .. raw:: html
+#
+#
+# Matrix A (4 x 4): +# ┌─────────────────┐ +# │ a11 a12 a13 a14 │ <- Rows 0–1 (Process Grid Row 0) +# │ a21 a22 a23 a24 │ +# ├─────────────────┤ +# │ a41 a42 a43 a44 │ <- Rows 2–3 (Process Grid Row 1) +# │ a51 a52 a53 a54 │ +# └─────────────────┘ +#
+# +# .. raw:: html +# +#
+# Matrix X (4 x 4): +# ┌─────────┬─────────┐ +# │ b11 b12 │ b13 b14 │ <- Cols 0–1 (Process Grid Col 0), Cols 2–3 (Process Grid Col 1) +# │ b21 b22 │ b23 b24 │ +# │ b31 b32 │ b33 b34 │ +# │ b41 b42 │ b43 b44 │ +# └─────────┴─────────┘ +#
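+#
+# The block extents used below follow a simple ceil-division rule. As a
+# minimal standalone sketch (using the same :math:`N = M = 4` and
+# :math:`2 \times 2` grid as above; the numbers are for illustration only):
+#
+# .. code-block:: python
+#
+#     import math
+#     blk = int(math.ceil(4 / 2))             # each block spans 2 rows/cols
+#     start, end = 1 * blk, min(4, 2 * blk)   # second block covers indices 2:4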
+#
+
+blk_rows = int(math.ceil(N / p_prime))
+blk_cols = int(math.ceil(M / p_prime))
+
+rs = col_id * blk_rows
+re = min(N, rs + blk_rows)
+my_own_rows = max(0, re - rs)
+
+cs = row_id * blk_cols
+ce = min(M, cs + blk_cols)
+my_own_cols = max(0, ce - cs)
+
+A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy()
+
+###############################################################################
+# We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
+# operator and the :py:class:`pylops_mpi.DistributedArray` storing the input
+# matrix :math:`\mathbf{X}`.
+Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32")
+
+col_lens = comm.allgather(my_own_cols)
+total_cols = np.sum(col_lens)
+x = DistributedArray(global_shape=K * total_cols,
+                     local_shapes=[K * col_len for col_len in col_lens],
+                     partition=Partition.SCATTER,
+                     mask=[i % p_prime for i in range(comm.Get_size())],
+                     base_comm=comm,
+                     dtype="float32")
+x[:] = X_p.flatten()
+
+###############################################################################
+# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which
+# effectively implements a distributed matrix-matrix multiplication
+# :math:`\mathbf{Y} = \mathbf{AX}`). Note that :math:`\mathbf{Y}` is
+# distributed in the same way as the input :math:`\mathbf{X}`.
+y = Aop @ x
+
+###############################################################################
+# Next we apply the adjoint pass :math:`\mathbf{x}_{adj} = \mathbf{A}^H \mathbf{y}`
+# (which effectively implements a distributed matrix-matrix multiplication
+# :math:`\mathbf{X}_{adj} = \mathbf{A}^H \mathbf{Y}`). Note that
+# :math:`\mathbf{X}_{adj}` is again distributed in the same way as the input
+# :math:`\mathbf{X}`.
+xadj = Aop.H @ y
+
+###############################################################################
+# To conclude we verify our result against the equivalent serial version of
+# the operation by gathering the resulting matrices on rank 0 and reorganizing
+# the returned 1-D arrays into 2-D arrays.
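+# The gathered 1-D arrays are simply the per-rank column blocks laid out one
+# after the other (``N * cnt`` forward entries and ``K * cnt`` adjoint entries
+# for a block of ``cnt`` columns), so below they are reshaped block by block
+# and stacked horizontally.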
+ +# Local benchmarks +y = y.asarray(masked=True) +col_counts = [min(blk_cols, M - j * blk_cols) for j in range(p_prime)] +y_blocks = [] +offset = 0 +for cnt in col_counts: + block_size = N * cnt + y_block = y[offset: offset + block_size] + if len(y_block) != 0: + y_blocks.append( + y_block.reshape(N, cnt) + ) + offset += block_size +y = np.hstack(y_blocks) + +xadj = xadj.asarray(masked=True) +xadj_blocks = [] +offset = 0 +for cnt in col_counts: + block_size = K * cnt + xadj_blk = xadj[offset: offset + block_size] + if len(xadj_blk) != 0: + xadj_blocks.append( + xadj_blk.reshape(K, cnt) + ) + offset += block_size +xadj = np.hstack(xadj_blocks) + +if rank == 0: + y_loc = (A @ X).squeeze() + xadj_loc = (A.T.dot(y_loc.conj())).conj().squeeze() + + if not np.allclose(y, y_loc, rtol=1e-6): + print("FORWARD VERIFICATION FAILED") + print(f'distributed: {y}') + print(f'expected: {y_loc}') + else: + print("FORWARD VERIFICATION PASSED") + + if not np.allclose(xadj, xadj_loc, rtol=1e-6): + print("ADJOINT VERIFICATION FAILED") + print(f'distributed: {xadj}') + print(f'expected: {xadj_loc}') + else: + print("ADJOINT VERIFICATION PASSED") diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py new file mode 100644 index 00000000..0dcee587 --- /dev/null +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -0,0 +1,269 @@ +import numpy as np +import math +from mpi4py import MPI +from pylops.utils.backend import get_module +from pylops.utils.typing import DTypeLike, NDArray + +from pylops_mpi import ( + DistributedArray, + MPILinearOperator, + Partition +) + + +class MPIMatrixMult(MPILinearOperator): + r"""MPI Matrix multiplication + + Implement distributed matrix-matrix multiplication between a matrix + :math:`\mathbf{A}` blocked over rows (i.e., blocks of rows are stored + over different ranks) and the input model and data vector, which are both to + be interpreted as matrices blocked over columns. + + Parameters + ---------- + A : :obj:`numpy.ndarray` + Local block of the matrix of shape :math:`[N_{loc} \times K]` + where :math:`N_{loc}` is the number of rows stored on this MPI rank and + ``K`` is the global number of columns. + M : :obj:`int` + Global leading dimension (i.e., number of columns) of the matrices + representing the input model and data vectors. + saveAt : :obj:`bool`, optional + Save ``A`` and ``A.H`` to speed up the computation of adjoint + (``True``) or create ``A.H`` on-the-fly (``False``) + Note that ``saveAt=True`` will double the amount of required memory. + Default is ``False``. + base_comm : :obj:`mpi4py.MPI.Comm`, optional + MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``. + dtype : :obj:`str`, optional + Type of elements in input array. + + Attributes + ---------- + shape : :obj:`tuple` + Operator shape + + Raises + ------ + Exception + If the operator is created with a non-square number of MPI ranks. + ValueError + If input vector does not have the correct partition type. 
+ + Notes + ----- + This operator performs a matrix-matrix multiplication, whose forward + operation can be described as :math:`Y = A \cdot X` where: + + - :math:`\mathbf{A}` is the distributed matrix operator of shape :math:`[N \times K]` + - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]` + - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]` + + whilst the adjoint operation is represented by + :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}` where + :math:`\mathbf{A}^H` is the complex conjugate and transpose of :math:`\mathbf{A}`. + + This implementation is based on a 1D block distribution of the operator + matrix and reshaped model and data vectors replicated across :math:`P` + processes by a factor equivalent to :math:`\sqrt{P}` across a square process + grid (:math:`\sqrt{P}\times\sqrt{P}`). More specifically: + + - The matrix ``A`` is distributed across MPI processes in a block-row fashion + and each process holds a local block of ``A`` with shape + :math:`[N_{loc} \times K]` + - The operand matrix ``X`` is distributed in a block-column fashion and + each process holds a local block of ``X`` with shape + :math:`[K \times M_{loc}]` + - Communication is minimized by using a 2D process grid layout + + **Forward Operation step-by-step** + + 1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X`` + of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local`` + is the number of columns assigned to the current process. + + 2. **Local Computation**: Each process computes ``A_local @ X_local`` where: + - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``) + - ``X_local`` is the broadcasted operand (shape ``K x M_local``) + + 3. **Row-wise Gather**: Results from all processes in each row are gathered + using ``allgather`` to ensure that each rank has a block-column of the + output matrix. + + **Adjoint Operation step-by-step** + + The adjoint operation performs the conjugate transpose multiplication: + + 1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(N, M_local)`` + representing the local columns of the input matrix. + + 2. **Local Adjoint Computation**: Each process computes + ``A_local.H @ X_tile`` where ``A_local.H`` is either i) Pre-computed + and stored in ``At`` (if ``saveAt=True``), ii) computed on-the-fly as + ``A.T.conj()`` (if ``saveAt=False``). Each process multiplies its + transposed local ``A`` block ``A_local^H`` (shape ``K x N_block``) + with the extracted ``X_tile`` (shape ``N_block x M_local``), + producing a partial result of shape ``(K, M_local)``. + This computes the local contribution of columns of ``A^H`` to the final + result. + + 3. **Row-wise Reduction**: Since the full result ``Y = A^H \cdot X`` is the + sum of the contributions from all column blocks of ``A^H``, processes in + the same row perform an ``allreduce`` sum to combine their partial results. + This gives the complete ``(K, M_local)`` result for their assigned column. 
+ + """ + def __init__( + self, + A: NDArray, + M: int, + saveAt: bool = False, + base_comm: MPI.Comm = MPI.COMM_WORLD, + dtype: DTypeLike = "float64", + ) -> None: + rank = base_comm.Get_rank() + size = base_comm.Get_size() + + # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size + self._P_prime = math.isqrt(size) + self._C = self._P_prime + if self._P_prime * self._C != size: + raise Exception(f"Number of processes must be a square number, provided {size} instead...") + + self._col_id = rank % self._P_prime + self._row_id = rank // self._P_prime + + self.base_comm = base_comm + self._row_comm = base_comm.Split(color=self._row_id, key=self._col_id) + self._col_comm = base_comm.Split(color=self._col_id, key=self._row_id) + + self.A = A.astype(np.dtype(dtype)) + if saveAt: + self.At = A.T.conj() + + self.N = self._row_comm.allreduce(self.A.shape[0], op=MPI.SUM) + self.K = A.shape[1] + self.M = M + + block_cols = int(math.ceil(self.M / self._P_prime)) + blk_rows = int(math.ceil(self.N / self._P_prime)) + + self._row_start = self._col_id * blk_rows + self._row_end = min(self.N, self._row_start + blk_rows) + + self._col_start = self._row_id * block_cols + self._col_end = min(self.M, self._col_start + block_cols) + + self._local_ncols = max(0, self._col_end - self._col_start) + self._rank_col_lens = self.base_comm.allgather(self._local_ncols) + total_ncols = np.sum(self._rank_col_lens) + + self.dims = (self.K, total_ncols) + self.dimsd = (self.N, total_ncols) + shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims))) + super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm) + + def _matvec(self, x: DistributedArray) -> DistributedArray: + ncp = get_module(x.engine) + if x.partition != Partition.SCATTER: + raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...") + + y = DistributedArray( + global_shape=(self.N * self.dimsd[1]), + local_shapes=[(self.N * c) for c in self._rank_col_lens], + mask=x.mask, + partition=Partition.SCATTER, + dtype=self.dtype, + base_comm=self.base_comm + ) + + my_own_cols = self._rank_col_lens[self.rank] + x_arr = x.local_array.reshape((self.dims[0], my_own_cols)) + X_local = x_arr.astype(self.dtype) + Y_local = ncp.vstack( + self._row_comm.allgather( + ncp.matmul(self.A, X_local) + ) + ) + y[:] = Y_local.flatten() + return y + + @staticmethod + def active_grid_comm(base_comm: MPI.Comm, N: int, M: int): + r"""Configure active grid + + Configure a square process grid from a parent MPI communicator and + select the subset of "active" processes. Each process in ``base_comm`` + is assigned to a logical 2D grid of size :math:`P' \times P'`, + where :math:`P' = \bigl \lceil \sqrt{P} \bigr \rceil`. Only the first + :math:`active_dim x active_dim` processes + (by row-major order) are considered "active". Inactive ranks return + immediately with no new communicator. + + Parameters: + ----------- + base_comm : :obj:`mpi4py.MPI.Comm` + MPI Parent Communicator. (e.g., ``mpi4py.MPI.COMM_WORLD``). + N : :obj:`int` + Number of rows of the global data domain. + M : :obj:`int` + Number of columns of the global data domain. + + Returns: + -------- + comm : :obj:`mpi4py.MPI.Comm` + Sub-communicator including only active ranks. + rank : :obj:`int` + Rank within the new sub-communicator (or original rank + if inactive). + row : :obj:`int` + Grid row index of this process in the active grid (or original rank + if inactive). 
+ col : :obj:`int` + Grid column index of this process in the active grid + (or original rank if inactive). + is_active : :obj:`bool` + Flag indicating whether this rank is in the active sub-grid. + + """ + rank = base_comm.Get_rank() + size = base_comm.Get_size() + p_prime = math.isqrt(size) + row, col = divmod(rank, p_prime) + active_dim = min(N, M, p_prime) + is_active = (row < active_dim and col < active_dim) + + if not is_active: + return None, rank, row, col, False + + active_ranks = [r for r in range(size) + if (r // p_prime) < active_dim and (r % p_prime) < active_dim] + new_group = base_comm.Get_group().Incl(active_ranks) + new_comm = base_comm.Create_group(new_group) + p_prime_new = math.isqrt(len(active_ranks)) + new_rank = new_comm.Get_rank() + new_row, new_col = divmod(new_rank, p_prime_new) + + return new_comm, new_rank, new_row, new_col, True + + def _rmatvec(self, x: DistributedArray) -> DistributedArray: + ncp = get_module(x.engine) + if x.partition != Partition.SCATTER: + raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.") + + y = DistributedArray( + global_shape=(self.K * self.dimsd[1]), + local_shapes=[self.K * c for c in self._rank_col_lens], + mask=x.mask, + partition=Partition.SCATTER, + dtype=self.dtype, + base_comm=self.base_comm + ) + + x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype) + X_tile = x_arr[self._row_start:self._row_end, :] + A_local = self.At if hasattr(self, "At") else self.A.T.conj() + Y_local = ncp.matmul(A_local, X_tile) + y_layer = self._row_comm.allreduce(Y_local, op=MPI.SUM) + y[:] = y_layer.flatten() + return y diff --git a/pylops_mpi/basicoperators/__init__.py b/pylops_mpi/basicoperators/__init__.py index 566b59fb..8db5a988 100644 --- a/pylops_mpi/basicoperators/__init__.py +++ b/pylops_mpi/basicoperators/__init__.py @@ -7,6 +7,7 @@ functionalities using MPI. A list of operators present in pylops_mpi.basicoperators: + MPIMatrixMult Matrix Multiplication operator MPIBlockDiag Block Diagonal arrangement of PyLops operators MPIStackedBlockDiag Block Diagonal arrangement of PyLops-MPI operators MPIVStack Vertical Stacking of PyLops operators @@ -19,6 +20,7 @@ """ +from .MatrixMult import * from .BlockDiag import * from .VStack import * from .HStack import * @@ -28,6 +30,7 @@ from .Gradient import * __all__ = [ + "MPIMatrixMult", "MPIBlockDiag", "MPIStackedBlockDiag", "MPIVStack", diff --git a/tests/test_matrixmult.py b/tests/test_matrixmult.py new file mode 100644 index 00000000..d7ea4c61 --- /dev/null +++ b/tests/test_matrixmult.py @@ -0,0 +1,133 @@ +"""Test the MPIMatrixMult class + Designed to run with n processes + $ mpiexec -n 10 pytest test_matrixmult.py --with-mpi +""" +import math +import numpy as np +from numpy.testing import assert_allclose +from mpi4py import MPI +import pytest + +from pylops_mpi import DistributedArray, Partition +from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult + +np.random.seed(42) +base_comm = MPI.COMM_WORLD +size = base_comm.Get_size() + +# Define test cases: (N, K, M, dtype_str) +# M, K, N are matrix dimensions A(N,K), B(K,M) +# P_prime will be ceil(sqrt(size)). 
+test_params = [ + pytest.param(37, 37, 37, "float32", id="f32_37_37_37"), + pytest.param(50, 30, 40, "float64", id="f64_50_30_40"), + pytest.param(22, 20, 16, "complex64", id="c64_22_20_16"), + pytest.param(3, 4, 5, "float32", id="f32_3_4_5"), + pytest.param(1, 2, 1, "float64", id="f64_1_2_1",), + pytest.param(2, 1, 3, "float32", id="f32_2_1_3",), +] + + +@pytest.mark.mpi(min_size=1) +@pytest.mark.parametrize("M, K, N, dtype_str", test_params) +def test_MPIMatrixMult(N, K, M, dtype_str): + dtype = np.dtype(dtype_str) + + cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0 + base_float_dtype = np.float32 if dtype == np.complex64 else np.float64 + + comm, rank, row_id, col_id, is_active = MPIMatrixMult.active_grid_comm(base_comm, N, M) + if not is_active: return + + size = comm.Get_size() + p_prime = math.isqrt(size) + + # Calculate local matrix dimensions + blk_rows_A = int(math.ceil(N / p_prime)) + row_start_A = col_id * blk_rows_A + row_end_A = min(N, row_start_A + blk_rows_A) + + blk_cols_X = int(math.ceil(M / p_prime)) + col_start_X = row_id * blk_cols_X + col_end_X = min(M, col_start_X + blk_cols_X) + local_col_X_len = max(0, col_end_X - col_start_X) + + A_glob_real = np.arange(N * K, dtype=base_float_dtype).reshape(N, K) + A_glob_imag = np.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5 + A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype) + + X_glob_real = np.arange(K * M, dtype=base_float_dtype).reshape(K, M) + X_glob_imag = np.arange(K * M, dtype=base_float_dtype).reshape(K, M) * 0.7 + X_glob = (X_glob_real + cmplx * X_glob_imag).astype(dtype) + + A_p = A_glob[row_start_A:row_end_A, :] + X_p = X_glob[:, col_start_X:col_end_X] + + # Create MPIMatrixMult operator + Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype=dtype_str) + + # Create DistributedArray for input x (representing B flattened) + all_local_col_len = comm.allgather(local_col_X_len) + total_cols = np.sum(all_local_col_len) + + x_dist = DistributedArray( + global_shape=(K * total_cols), + local_shapes=[K * cl_b for cl_b in all_local_col_len], + partition=Partition.SCATTER, + base_comm=comm, + mask=[i % p_prime for i in range(size)], + dtype=dtype + ) + + x_dist.local_array[:] = X_p.ravel() + + # Forward operation: y = A @ x (distributed) + y_dist = Aop @ x_dist + + # Adjoint operation: xadj = A.H @ y (distributed) + xadj_dist = Aop.H @ y_dist + + # Re-organize in local matrix + y = y_dist.asarray(masked=True) + col_counts = [min(blk_cols_X, M - j * blk_cols_X) for j in range(p_prime)] + y_blocks = [] + offset = 0 + for cnt in col_counts: + block_size = N * cnt + y_block = y[offset: offset + block_size] + if len(y_block) != 0: + y_blocks.append( + y_block.reshape(N, cnt) + ) + offset += block_size + y = np.hstack(y_blocks) + + xadj = xadj_dist.asarray(masked=True) + xadj_blocks = [] + offset = 0 + for cnt in col_counts: + block_size = K * cnt + xadj_blk = xadj[offset: offset + block_size] + if len(xadj_blk) != 0: + xadj_blocks.append( + xadj_blk.reshape(K, cnt) + ) + offset += block_size + xadj = np.hstack(xadj_blocks) + + if rank == 0: + y_loc = A_glob @ X_glob + assert_allclose( + y.squeeze(), + y_loc.squeeze(), + rtol=np.finfo(np.dtype(dtype)).resolution, + err_msg=f"Rank {rank}: Forward verification failed." + ) + + xadj_loc = A_glob.conj().T @ y_loc + assert_allclose( + xadj.squeeze(), + xadj_loc.squeeze(), + rtol=np.finfo(np.dtype(dtype)).resolution, + err_msg=f"Rank {rank}: Ajoint verification failed." + ) \ No newline at end of file