import numpy as np

from mpi4py import MPI
from pylops.utils.backend import get_module
from pylops.utils.typing import DTypeLike, NDArray

from pylops_mpi import (
    DistributedArray,
    MPILinearOperator,
    Partition
)

from pylops_mpi.utils.decorators import reshaped


class MPIFredholm1(MPILinearOperator):
    r"""Fredholm integral of first kind.

    Implement a multi-dimensional Fredholm integral of first kind distributed
    across the first dimension.

    Parameters
    ----------
    G : :obj:`numpy.ndarray`
        Multi-dimensional convolution kernel of size
        :math:`[n_{\text{slice}} \times n_x \times n_y]`
    nz : :obj:`int`, optional
        Additional dimension of model
    saveGt : :obj:`bool`, optional
        Save ``G`` and ``G.H`` to speed up the computation of the adjoint
        (``True``) or create ``G.H`` on-the-fly (``False``).
        Note that ``saveGt=True`` will double the amount of required memory.
    usematmul : :obj:`bool`, optional
        Use :func:`numpy.matmul` (``True``) or a for-loop with :func:`numpy.dot`
        (``False``). As it is not possible to define which approach is more
        performant (this is highly dependent on the size of ``G`` and of the
        input arrays, as well as on the hardware used in the computation), we
        advise users to time both methods for their specific problem prior to
        making a choice.
    base_comm : :obj:`mpi4py.MPI.Comm`, optional
        MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
    dtype : :obj:`str`, optional
        Type of elements in input array.

    Attributes
    ----------
    shape : :obj:`tuple`
        Operator shape

    Raises
    ------
    NotImplementedError
        If the size of the first dimension of ``G`` is equal to 1 in any of the ranks

    Notes
    -----
    A multi-dimensional Fredholm integral of first kind can be expressed as

    .. math::

        d(k, x, z) = \int{G(k, x, y) m(k, y, z) \,\mathrm{d}y}
        \quad \forall k=1,\ldots,n_{\text{slice}}

    whilst its adjoint is expressed as

    .. math::

        m(k, y, z) = \int{G^*(k, y, x) d(k, x, z) \,\mathrm{d}x}
        \quad \forall k=1,\ldots,n_{\text{slice}}

    This integral is implemented in a distributed fashion, where ``G``
    is split across ranks along its first dimension. The inputs
    of both the forward and the adjoint are distributed arrays with broadcast
    partition: each rank takes a portion of such arrays, computes a partial
    integral, and the resulting outputs are then gathered by all ranks to
    return a distributed array with broadcast partition. A sketch of the local
    kernels and a minimal usage example are provided at the end of this module.

    """

    def __init__(
            self,
            G: NDArray,
            nz: int = 1,
            saveGt: bool = False,
            usematmul: bool = True,
            base_comm: MPI.Comm = MPI.COMM_WORLD,
            dtype: DTypeLike = "float64",
    ) -> None:
        self.nz = nz
        self.nsl, self.nx, self.ny = G.shape
        # gather the local number of slices held by every rank
        self.nsls = base_comm.allgather(self.nsl)
        if base_comm.Get_rank() == 0 and 1 in self.nsls:
            raise NotImplementedError(f'All ranks must have at least 2 or more '
                                      f'elements in the first dimension: '
                                      f'local split is instead {self.nsls}...')
        nslstot = base_comm.allreduce(self.nsl)
        # start/end indices of the slices owned by each rank
        self.islstart = np.insert(np.cumsum(self.nsls)[:-1], 0, 0)
        self.islend = np.cumsum(self.nsls)
        self.rank = base_comm.Get_rank()
        self.dims = (nslstot, self.ny, self.nz)
        self.dimsd = (nslstot, self.nx, self.nz)
        shape = (np.prod(self.dimsd),
                 np.prod(self.dims))
        super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)

        self.G = G
        if saveGt:
            # precompute and store the conjugate-transposed kernel for the adjoint
            self.GT = G.transpose((0, 2, 1)).conj()
        self.usematmul = usematmul

    def _matvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition is not Partition.BROADCAST:
            raise ValueError(f"x should have partition={Partition.BROADCAST}, {x.partition} != {Partition.BROADCAST}")
        y = DistributedArray(global_shape=self.shape[0], partition=Partition.BROADCAST,
                             engine=x.engine, dtype=self.dtype)
        # extract the portion of the model handled by this rank
        x = x.local_array.reshape(self.dims).squeeze()
        x = x[self.islstart[self.rank]:self.islend[self.rank]]
        # apply the integral to the slices owned by this rank
        if self.usematmul:
            if self.nz == 1:
                x = x[..., ncp.newaxis]
            y1 = ncp.matmul(self.G, x)
        else:
            y1 = ncp.squeeze(ncp.zeros((self.nsls[self.rank], self.nx, self.nz), dtype=self.dtype))
            for isl in range(self.nsls[self.rank]):
                y1[isl] = ncp.dot(self.G[isl], x[isl])
        # gather results from all ranks into the broadcasted output
        y[:] = np.vstack(self.base_comm.allgather(y1)).ravel()
        return y

    def _rmatvec(self, x: DistributedArray) -> DistributedArray:
        ncp = get_module(x.engine)
        if x.partition is not Partition.BROADCAST:
            raise ValueError(f"x should have partition={Partition.BROADCAST}, {x.partition} != {Partition.BROADCAST}")
        y = DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST,
                             engine=x.engine, dtype=self.dtype)
        # extract the portion of the data handled by this rank
        x = x.local_array.reshape(self.dimsd).squeeze()
        x = x[self.islstart[self.rank]:self.islend[self.rank]]
        # apply the adjoint integral to the slices owned by this rank
        if self.usematmul:
            if self.nz == 1:
                x = x[..., ncp.newaxis]
            if hasattr(self, "GT"):
                y1 = ncp.matmul(self.GT, x)
            else:
                y1 = (
                    ncp.matmul(x.transpose(0, 2, 1).conj(), self.G)
                    .transpose(0, 2, 1)
                    .conj()
                )
        else:
            y1 = ncp.squeeze(ncp.zeros((self.nsls[self.rank], self.ny, self.nz), dtype=self.dtype))
            if hasattr(self, "GT"):
                for isl in range(self.nsls[self.rank]):
                    y1[isl] = ncp.dot(self.GT[isl], x[isl])
            else:
                for isl in range(self.nsl):
                    y1[isl] = ncp.dot(x[isl].T.conj(), self.G[isl]).T.conj()

        # gather results from all ranks into the broadcasted output
        y[:] = np.vstack(self.base_comm.allgather(y1)).ravel()
        return y

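
# The docstring above advises timing both ``usematmul`` variants. The helper below
# is NOT part of pylops_mpi: it is a minimal, NumPy-only sketch of the two local
# kernels the flag switches between (a single batched ``np.matmul`` over the stacked
# kernel versus a per-slice loop of ``np.dot``), useful as a starting point for such
# a timing comparison on a single rank. Its name and signature are illustrative only.
def _local_fredholm1_sketch(G: NDArray, m: NDArray, usematmul: bool = True) -> NDArray:
    """Apply d[k] = G[k] @ m[k] for every local slice k (hypothetical helper)."""
    if usematmul:
        # batched matrix multiply over the slice axis
        return np.matmul(G, m)
    # explicit loop over slices with np.dot
    d = np.zeros((G.shape[0], G.shape[1], m.shape[2]), dtype=G.dtype)
    for isl in range(G.shape[0]):
        d[isl] = np.dot(G[isl], m[isl])
    return d
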
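
# A minimal end-to-end usage sketch (assuming this module is run with, e.g.,
# ``mpiexec -n 2 python thisfile.py``). All sizes, seeds, and variable names below
# are illustrative assumptions, not part of the library: each rank builds its own
# portion of the kernel ``G``, and the model/data are broadcasted DistributedArrays.
if __name__ == "__main__":
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # local kernel slice on this rank: (nsl, nx, ny), with nsl >= 2 as required
    nsl, nx, ny, nz = 4, 5, 3, 2
    rng = np.random.default_rng(seed=rank)
    G = rng.standard_normal((nsl, nx, ny))

    Fop = MPIFredholm1(G, nz=nz, saveGt=True, usematmul=True,
                       base_comm=comm, dtype="float64")

    # model with broadcast partition: every rank holds the full flattened model
    m = DistributedArray(global_shape=Fop.shape[1],
                         partition=Partition.BROADCAST, dtype="float64")
    m[:] = np.ones(Fop.shape[1])

    d = Fop.matvec(m)        # forward: d(k, x, z), broadcasted to all ranks
    madj = Fop.rmatvec(d)    # adjoint: m(k, y, z), broadcasted to all ranks

    if rank == 0:
        print("data shape:", d.local_array.shape,
              "adjoint model shape:", madj.local_array.shape)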