import numpy as np

from mpi4py import MPI
from pylops.utils.backend import get_module
from pylops.utils.typing import DTypeLike, NDArray

from pylops_mpi import (
    DistributedArray,
    MPILinearOperator,
    Partition
)


class MPIFredholm1(MPILinearOperator):
    r"""Fredholm integral of first kind.

    Implement a multi-dimensional Fredholm integral of first kind distributed
    across the first dimension.

    Parameters
    ----------
    G : :obj:`numpy.ndarray`
        Multi-dimensional convolution kernel of size
        :math:`[n_{\text{slice}} \times n_x \times n_y]`
    nz : :obj:`int`, optional
        Additional dimension of model
    saveGt : :obj:`bool`, optional
        Save ``G`` and ``G.H`` to speed up the computation of the adjoint
        (``True``) or create ``G.H`` on-the-fly (``False``).
        Note that ``saveGt=True`` will double the amount of required memory.
    usematmul : :obj:`bool`, optional
        Use :func:`numpy.matmul` (``True``) or for-loop with :func:`numpy.dot`
        (``False``). As it is not possible to define which approach is more
        performant (this is highly dependent on the size of ``G`` and input
        arrays as well as the hardware used in the computation), we advise users
        to time both methods for their specific problem prior to making a
        choice.
    base_comm : :obj:`mpi4py.MPI.Comm`, optional
        MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
    dtype : :obj:`str`, optional
        Type of elements in input array.

    Attributes
    ----------
    shape : :obj:`tuple`
        Operator shape

    Raises
    ------
    NotImplementedError
        If the size of the first dimension of ``G`` is equal to 1 in any of the ranks

    Notes
    -----
    A multi-dimensional Fredholm integral of first kind can be expressed as

    .. math::

        d(k, x, z) = \int{G(k, x, y) m(k, y, z) \,\mathrm{d}y}
        \quad \forall k=1,\ldots,n_{\text{slice}}

    whilst its adjoint is expressed as

    .. math::

        m(k, y, z) = \int{G^*(k, y, x) d(k, x, z) \,\mathrm{d}x}
        \quad \forall k=1,\ldots,n_{\text{slice}}

    This integral is implemented in a distributed fashion, where ``G``
    is split across ranks along its first dimension. The inputs of both
    the forward and adjoint operations are distributed arrays with broadcast
    partition: each rank takes a portion of such arrays, computes a partial
    integral, and the resulting outputs are then gathered by all ranks to
    return a distributed array with broadcast partition.

    """

    def __init__(
        self,
        G: NDArray,
        nz: int = 1,
        saveGt: bool = False,
        usematmul: bool = True,
        base_comm: MPI.Comm = MPI.COMM_WORLD,
        dtype: DTypeLike = "float64",
    ) -> None:
        self.nz = nz
        self.nsl, self.nx, self.ny = G.shape
        self.nsls = base_comm.allgather(self.nsl)
        # nsls is identical on all ranks after allgather, so raise on every
        # rank to avoid leaving some ranks hanging in later collective calls
        if 1 in self.nsls:
            raise NotImplementedError(f'All ranks must have at least two '
                                      f'elements in the first dimension: '
                                      f'local split is instead {self.nsls}...')
        nslstot = base_comm.allreduce(self.nsl)
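        # start/end indices of the slices owned by each rank within the
        # global first dimension (used to extract each rank's portion of x)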
        self.islstart = np.insert(np.cumsum(self.nsls)[:-1], 0, 0)
        self.islend = np.cumsum(self.nsls)
        self.rank = base_comm.Get_rank()
        self.dims = (nslstot, self.ny, self.nz)
        self.dimsd = (nslstot, self.nx, self.nz)
        shape = (np.prod(self.dimsd),
                 np.prod(self.dims))
        super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)

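        # store the local kernel and, optionally, precompute its conjugate
        # transpose (saveGt trades extra memory for a faster adjoint)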
        self.G = G
        if saveGt:
            self.GT = G.transpose((0, 2, 1)).conj()
        self.usematmul = usematmul

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition is not Partition.BROADCAST:
            raise ValueError(f"x should have partition={Partition.BROADCAST}, "
                             f"got {x.partition} instead")
        y = DistributedArray(global_shape=self.shape[0], partition=Partition.BROADCAST,
                             engine=x.engine, dtype=self.dtype)
        # select the slices of the broadcast model handled by this rank
        x = x.local_array.reshape(self.dims).squeeze()
        x = x[self.islstart[self.rank]:self.islend[self.rank]]
        # apply kernel to the portion of the slices owned by this rank
        if self.usematmul:
            if self.nz == 1:
                x = x[..., ncp.newaxis]
            y1 = ncp.matmul(self.G, x)
        else:
            y1 = ncp.squeeze(ncp.zeros((self.nsls[self.rank], self.nx, self.nz), dtype=self.dtype))
            for isl in range(self.nsls[self.rank]):
                y1[isl] = ncp.dot(self.G[isl], x[isl])
        # gather partial results from all ranks into the broadcast output
        y[:] = np.vstack(self.base_comm.allgather(y1)).ravel()
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition is not Partition.BROADCAST:
            raise ValueError(f"x should have partition={Partition.BROADCAST}, "
                             f"got {x.partition} instead")
        y = DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST,
                             engine=x.engine, dtype=self.dtype)
        # select the slices of the broadcast data handled by this rank
        x = x.local_array.reshape(self.dimsd).squeeze()
        x = x[self.islstart[self.rank]:self.islend[self.rank]]
        # apply adjoint kernel to the portion of the slices owned by this rank
        if self.usematmul:
            if self.nz == 1:
                x = x[..., ncp.newaxis]
            if hasattr(self, "GT"):
                y1 = ncp.matmul(self.GT, x)
            else:
                y1 = (
                    ncp.matmul(x.transpose(0, 2, 1).conj(), self.G)
                    .transpose(0, 2, 1)
                    .conj()
                )
        else:
            y1 = ncp.squeeze(ncp.zeros((self.nsls[self.rank], self.ny, self.nz), dtype=self.dtype))
            if hasattr(self, "GT"):
                for isl in range(self.nsls[self.rank]):
                    y1[isl] = ncp.dot(self.GT[isl], x[isl])
            else:
                for isl in range(self.nsls[self.rank]):
                    y1[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj()

        # gather partial results from all ranks into the broadcast output
        y[:] = np.vstack(self.base_comm.allgather(y1)).ravel()
        return y
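

# A minimal usage sketch (hypothetical sizes; run with e.g. ``mpirun -n 2``).
# Each rank holds its own block of slices of ``G``, while model and data are
# broadcast to all ranks; ``matvec``/``rmatvec`` then compute the distributed
# forward and adjoint integrals.
if __name__ == "__main__":
    nsl, nx, ny, nz = 4, 5, 3, 2  # per-rank slices and integral sizes (assumed)
    G = np.arange(nsl * nx * ny, dtype="float64").reshape(nsl, nx, ny)
    Fop = MPIFredholm1(G, nz=nz, saveGt=True, dtype="float64")
    # broadcast model: every rank holds the full flattened model
    m = DistributedArray(global_shape=Fop.shape[1],
                         partition=Partition.BROADCAST, dtype="float64")
    m[:] = np.ones(Fop.shape[1], dtype="float64")
    d = Fop.matvec(m)      # forward: d(k, x, z)
    madj = Fop.rmatvec(d)  # adjoint: m(k, y, z)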