@@ -163,6 +163,9 @@ class Sliding2D(LinearOperator):
163163 Number of samples of overlapping part of window
164164 tapertype : :obj:`str`, optional
165165 Type of taper (``hanning``, ``cosine``, ``cosinesquare`` or ``None``)
166+ savetaper: :obj:`bool`, optional
167+ Save all tapers and apply them in one go (``True``) or save unique tapers and apply them one by one (``False``).
168+ The first option is more computationally efficient, whilst the second is more memory efficient.
166169 name : :obj:`str`, optional
167170 .. versionadded:: 2.0.0
168171
@@ -189,6 +192,7 @@ def __init__(
189192 nwin : int ,
190193 nover : int ,
191194 tapertype : str = "hanning" ,
195+ savetaper : bool = True ,
192196 name : str = "S" ,
193197 ) -> None :
194198
@@ -213,19 +217,25 @@ def __init__(
213217
214218 # create tapers
215219 self .tapertype = tapertype
220+ self .savetaper = savetaper
216221 if self .tapertype is not None :
217222 tap = taper2d (dimsd [1 ], nwin , nover , tapertype = self .tapertype )
218223 tapin = tap .copy ()
219224 tapin [:nover ] = 1
220225 tapend = tap .copy ()
221226 tapend [- nover :] = 1
222- self .taps = [
223- tapin [np .newaxis , :],
224- ]
225- for i in range (1 , nwins - 1 ):
226- self .taps .append (tap [np .newaxis , :])
227- self .taps .append (tapend [np .newaxis , :])
228- self .taps = np .concatenate (self .taps , axis = 0 )
227+ if self .savetaper :
228+ self .taps = [
229+ tapin [np .newaxis , :],
230+ ]
231+ for i in range (1 , nwins - 1 ):
232+ self .taps .append (tap [np .newaxis , :])
233+ self .taps .append (tapend [np .newaxis , :])
234+ self .taps = np .concatenate (self .taps , axis = 0 )
235+ else :
236+ self .taps = np .vstack (
237+ [tapin [np .newaxis , :], tap [np .newaxis , :], tapend [np .newaxis , :]]
238+ )
229239
230240 # check if operator is applied to all windows simultaneously
231241 self .simOp = False
@@ -241,8 +251,10 @@ def __init__(
241251 name = name ,
242252 )
243253
254+ self ._register_multiplications (self .savetaper )
255+
244256 @reshaped
245- def _matvec (self , x : NDArray ) -> NDArray :
257+ def _matvec_savetaper (self , x : NDArray ) -> NDArray :
246258 ncp = get_array_module (x )
247259 if self .tapertype is not None :
248260 self .taps = to_cupy_conditional (x , self .taps )
@@ -262,7 +274,7 @@ def _matvec(self, x: NDArray) -> NDArray:
262274 return y
263275
264276 @reshaped
265- def _rmatvec (self , x : NDArray ) -> NDArray :
277+ def _rmatvec_savetaper (self , x : NDArray ) -> NDArray :
266278 ncp = get_array_module (x )
267279 ncp_sliding_window_view = get_sliding_window_view (x )
268280 if self .tapertype is not None :
@@ -281,3 +293,80 @@ def _rmatvec(self, x: NDArray) -> NDArray:
281293 self .dims [1 ], self .dims [2 ]
282294 )
283295 return y
296+
297+ @reshaped
298+ def _matvec_nosavetaper (self , x : NDArray ) -> NDArray :
299+ ncp = get_array_module (x )
300+ if self .tapertype is not None :
301+ self .taps = to_cupy_conditional (x , self .taps )
302+ y = ncp .zeros (self .dimsd , dtype = self .dtype )
303+ if self .simOp :
304+ x = self .Op @ x
305+ for iwin0 in range (self .dims [0 ]):
306+ if self .simOp :
307+ xxwin = x [iwin0 ].reshape (self .nwin , self .dimsd [- 1 ])
308+ else :
309+ xxwin = self .Op .matvec (x [iwin0 ].ravel ()).reshape (
310+ self .nwin , self .dimsd [- 1 ]
311+ )
312+ if self .tapertype is not None :
313+ if iwin0 == 0 :
314+ xxwin = self .taps [0 ] * xxwin
315+ elif iwin0 == self .dims [0 ] - 1 :
316+ xxwin = self .taps [- 1 ] * xxwin
317+ else :
318+ xxwin = self .taps [1 ] * xxwin
319+ y [self .dwin_inends [0 ][iwin0 ] : self .dwin_inends [1 ][iwin0 ]] += xxwin
320+ return y
321+
322+ @reshaped
323+ def _rmatvec_nosavetaper (self , x : NDArray ) -> NDArray :
324+ ncp = get_array_module (x )
325+ ncp_sliding_window_view = get_sliding_window_view (x )
326+ if self .tapertype is not None :
327+ self .taps = to_cupy_conditional (x , self .taps )
328+ ywins = (
329+ ncp_sliding_window_view (x , self .nwin , axis = 0 )[:: self .nwin - self .nover ]
330+ .transpose (0 , 2 , 1 )
331+ .copy ()
332+ )
333+ if self .simOp :
334+ if self .tapertype is not None :
335+ for iwin0 in range (self .dims [0 ]):
336+ if iwin0 == 0 :
337+ ywins [0 ] = ywins [0 ] * self .taps [0 ]
338+ elif iwin0 == self .dims [0 ] - 1 :
339+ ywins [- 1 ] = ywins [- 1 ] * self .taps [- 1 ]
340+ else :
341+ ywins [iwin0 ] = ywins [iwin0 ] * self .taps [1 ]
342+ y = self .Op .H @ ywins
343+ else :
344+ y = ncp .zeros (self .dims , dtype = self .dtype )
345+ for iwin0 in range (self .dims [0 ]):
346+ if iwin0 == 0 :
347+ if self .tapertype is not None :
348+ ywins [0 ] = ywins [0 ] * self .taps [0 ]
349+ y [0 ] = self .Op .rmatvec (ywins [0 ].ravel ()).reshape (
350+ self .dims [1 ], self .dims [2 ]
351+ )
352+ elif iwin0 == self .dims [0 ] - 1 :
353+ if self .tapertype is not None :
354+ ywins [- 1 ] = ywins [- 1 ] * self .taps [- 1 ]
355+ y [- 1 ] = self .Op .rmatvec (ywins [- 1 ].ravel ()).reshape (
356+ self .dims [1 ], self .dims [2 ]
357+ )
358+ else :
359+ if self .tapertype is not None :
360+ ywins [iwin0 ] = ywins [iwin0 ] * self .taps [1 ]
361+ y [iwin0 ] = self .Op .rmatvec (ywins [iwin0 ].ravel ()).reshape (
362+ self .dims [1 ], self .dims [2 ]
363+ )
364+ return y
365+
366+ def _register_multiplications (self , savetaper : bool ) -> None :
367+ if savetaper :
368+ self ._matvec = self ._matvec_savetaper
369+ self ._rmatvec = self ._rmatvec_savetaper
370+ else :
371+ self ._matvec = self ._matvec_nosavetaper
372+ self ._rmatvec = self ._rmatvec_nosavetaper
0 commit comments