99import numpy as np
1010
1111from pylops import LinearOperator
12- from pylops .basicoperators import BlockDiag , Diagonal , HStack , Restriction
12+ from pylops .utils ._internal import _value_or_sized_to_tuple
13+ from pylops .utils .backend import (
14+ get_array_module ,
15+ get_sliding_window_view ,
16+ to_cupy_conditional ,
17+ )
18+ from pylops .utils .decorators import reshaped
1319from pylops .utils .tapers import taper2d
1420from pylops .utils .typing import InputDimsLike , NDArray
1521
@@ -110,15 +116,7 @@ def sliding2d_design(
110116 return nwins , dims , mwins_inends , dwins_inends
111117
112118
113- def Sliding2D (
114- Op : LinearOperator ,
115- dims : InputDimsLike ,
116- dimsd : InputDimsLike ,
117- nwin : int ,
118- nover : int ,
119- tapertype : str = "hanning" ,
120- name : str = "S" ,
121- ) -> LinearOperator :
119+ class Sliding2D (LinearOperator ):
122120 """2D Sliding transform operator.
123121
124122 Apply a transform operator ``Op`` repeatedly to slices of the model
@@ -139,6 +137,12 @@ def Sliding2D(
139137 ``nover``, it is recommended to first run ``sliding2d_design`` to obtain
140138 the corresponding ``dims`` and number of windows.
141139
140+ .. note:: Two kind of operators ``Op`` can be provided: the first
141+ applies a single transformation to each window separately; the second
142+ applies the transformation to all of the windows at the same time. This
143+ is directly inferred during initialization when the following condition
144+ holds ``Op.shape[1] == np.prod(dims)``.
145+
142146 .. warning:: Depending on the choice of `nwin` and `nover` as well as the
143147 size of the data, sliding windows may not cover the entire data.
144148 The start and end indices of each window will be displayed and returned
@@ -176,47 +180,104 @@ def Sliding2D(
176180 shape (``dims``).
177181
178182 """
179- # data windows
180- dwin_ins , dwin_ends = _slidingsteps (dimsd [0 ], nwin , nover )
181- nwins = len (dwin_ins )
182-
183- # check patching
184- if nwins * Op .shape [1 ] // dims [1 ] != dims [0 ]:
185- raise ValueError (
186- f"Model shape (dims={ dims } ) is not consistent with chosen "
187- f"number of windows. Run sliding2d_design to identify the "
188- f"correct number of windows for the current "
189- "model size..."
190- )
191183
192- # create tapers
193- if tapertype is not None :
194- tap = taper2d (dimsd [1 ], nwin , nover , tapertype = tapertype ).astype (Op .dtype )
195- tapin = tap .copy ()
196- tapin [:nover ] = 1
197- tapend = tap .copy ()
198- tapend [- nover :] = 1
199- taps = {}
200- taps [0 ] = tapin
201- for i in range (1 , nwins - 1 ):
202- taps [i ] = tap
203- taps [nwins - 1 ] = tapend
204-
205- # transform to apply
206- if tapertype is None :
207- OOp = BlockDiag ([Op for _ in range (nwins )])
208- else :
209- OOp = BlockDiag (
210- [Diagonal (taps [itap ].ravel (), dtype = Op .dtype ) * Op for itap in range (nwins )]
184+ def __init__ (
185+ self ,
186+ Op : LinearOperator ,
187+ dims : InputDimsLike ,
188+ dimsd : InputDimsLike ,
189+ nwin : int ,
190+ nover : int ,
191+ tapertype : str = "hanning" ,
192+ name : str = "S" ,
193+ ) -> None :
194+
195+ dims : Tuple [int , ...] = _value_or_sized_to_tuple (dims )
196+ dimsd : Tuple [int , ...] = _value_or_sized_to_tuple (dimsd )
197+
198+ # data windows
199+ dwin_ins , dwin_ends = _slidingsteps (dimsd [0 ], nwin , nover )
200+ self .dwin_inends = (dwin_ins , dwin_ends )
201+ nwins = len (dwin_ins )
202+ self .nwin = nwin
203+ self .nover = nover
204+
205+ # check patching
206+ if nwins * Op .shape [1 ] // dims [1 ] != dims [0 ] and Op .shape [1 ] != np .prod (dims ):
207+ raise ValueError (
208+ f"Model shape (dims={ dims } ) is not consistent with chosen "
209+ f"number of windows. Run sliding2d_design to identify the "
210+ f"correct number of windows for the current "
211+ "model size..."
212+ )
213+
214+ # create tapers
215+ self .tapertype = tapertype
216+ if self .tapertype is not None :
217+ tap = taper2d (dimsd [1 ], nwin , nover , tapertype = self .tapertype )
218+ tapin = tap .copy ()
219+ tapin [:nover ] = 1
220+ tapend = tap .copy ()
221+ tapend [- nover :] = 1
222+ self .taps = [
223+ tapin [np .newaxis , :],
224+ ]
225+ for i in range (1 , nwins - 1 ):
226+ self .taps .append (tap [np .newaxis , :])
227+ self .taps .append (tapend [np .newaxis , :])
228+ self .taps = np .concatenate (self .taps , axis = 0 )
229+
230+ # check if operator is applied to all windows simultaneously
231+ self .simOp = False
232+ if Op .shape [1 ] == np .prod (dims ):
233+ self .simOp = True
234+ self .Op = Op
235+
236+ super ().__init__ (
237+ dtype = Op .dtype ,
238+ dims = (nwins , int (dims [0 ] // nwins ), dims [1 ]),
239+ dimsd = dimsd ,
240+ clinear = False ,
241+ name = name ,
211242 )
212243
213- combining = HStack (
214- [
215- Restriction (dimsd , range (win_in , win_end ), axis = 0 , dtype = Op .dtype ).H
216- for win_in , win_end in zip (dwin_ins , dwin_ends )
217- ]
218- )
219- Sop = LinearOperator (combining * OOp )
220- Sop .dims , Sop .dimsd = (nwins , int (dims [0 ] // nwins ), dims [1 ]), dimsd
221- Sop .name = name
222- return Sop
244+ @reshaped
245+ def _matvec (self , x : NDArray ) -> NDArray :
246+ ncp = get_array_module (x )
247+ if self .tapertype is not None :
248+ self .taps = to_cupy_conditional (x , self .taps )
249+ y = ncp .zeros (self .dimsd , dtype = self .dtype )
250+ if self .simOp :
251+ x = self .Op @ x
252+ for iwin0 in range (self .dims [0 ]):
253+ if self .simOp :
254+ xx = x [iwin0 ].reshape (self .nwin , self .dimsd [- 1 ])
255+ else :
256+ xx = self .Op .matvec (x [iwin0 ].ravel ()).reshape (self .nwin , self .dimsd [- 1 ])
257+ if self .tapertype is not None :
258+ xxwin = self .taps [iwin0 ] * xx
259+ else :
260+ xxwin = xx
261+ y [self .dwin_inends [0 ][iwin0 ] : self .dwin_inends [1 ][iwin0 ]] += xxwin
262+ return y
263+
264+ @reshaped
265+ def _rmatvec (self , x : NDArray ) -> NDArray :
266+ ncp = get_array_module (x )
267+ ncp_sliding_window_view = get_sliding_window_view (x )
268+ if self .tapertype is not None :
269+ self .taps = to_cupy_conditional (x , self .taps )
270+ ywins = ncp_sliding_window_view (x , self .nwin , axis = 0 )[
271+ :: self .nwin - self .nover
272+ ].transpose (0 , 2 , 1 )
273+ if self .tapertype is not None :
274+ ywins = ywins * self .taps
275+ if self .simOp :
276+ y = self .Op .H @ ywins
277+ else :
278+ y = ncp .zeros (self .dims , dtype = self .dtype )
279+ for iwin0 in range (self .dims [0 ]):
280+ y [iwin0 ] = self .Op .rmatvec (ywins [iwin0 ].ravel ()).reshape (
281+ self .dims [1 ], self .dims [2 ]
282+ )
283+ return y
0 commit comments