@@ -141,6 +141,9 @@ class Patch2D(LinearOperator):
141141 Size of model in the transformed domain
142142 tapertype : :obj:`str`, optional
143143 Type of taper (``hanning``, ``cosine``, ``cosinesquare`` or ``None``)
144+ savetaper: :obj:`bool`, optional
145+ Save all tapers and apply them in one go (``True``) or save unique tapers and apply them one by one (``False``).
146+ The first option is more computationally efficient, whilst the second is more memory efficient.
144147 scalings : :obj:`tuple` or :obj:`list`, optional
145148 Set of scalings to apply to each patch. If ``None``, no scale will be
146149 applied
@@ -178,6 +181,7 @@ def __init__(
178181 nover : Tuple [int , int ],
179182 nop : Tuple [int , int ],
180183 tapertype : str = "hanning" ,
184+ savetaper : bool = True ,
181185 scalings : Optional [Sequence [float ]] = None ,
182186 name : str = "P" ,
183187 ) -> None :
@@ -206,52 +210,68 @@ def __init__(
206210
207211 # create tapers
208212 self .tapertype = tapertype
213+ self .savetaper = savetaper
209214 if self .tapertype is not None :
210215 tap = taper2d (nwin [1 ], nwin [0 ], nover , tapertype = tapertype ).astype (Op .dtype )
211- taps = [
212- tap ,
213- ] * nwins
214216 # topmost tapers
215217 taptop = tap .copy ()
216218 taptop [: nover [0 ]] = tap [nwin [0 ] // 2 ]
217- for itap in range (0 , nwins1 ):
218- taps [itap ] = taptop
219219 # bottommost tapers
220220 tapbottom = tap .copy ()
221221 tapbottom [- nover [0 ] :] = tap [nwin [0 ] // 2 ]
222- for itap in range (nwins - nwins1 , nwins ):
223- taps [itap ] = tapbottom
224222 # leftmost tapers
225223 tapleft = tap .copy ()
226224 tapleft [:, : nover [1 ]] = tap [:, nwin [1 ] // 2 ][:, np .newaxis ]
227- for itap in range (0 , nwins , nwins1 ):
228- taps [itap ] = tapleft
229225 # rightmost tapers
230226 tapright = tap .copy ()
231227 tapright [:, - nover [1 ] :] = tap [:, nwin [1 ] // 2 ][:, np .newaxis ]
232- for itap in range (nwins1 - 1 , nwins , nwins1 ):
233- taps [itap ] = tapright
234228 # lefttopcorner taper
235229 taplefttop = tap .copy ()
236230 taplefttop [:, : nover [1 ]] = tap [:, nwin [1 ] // 2 ][:, np .newaxis ]
237231 taplefttop [: nover [0 ]] = taplefttop [nwin [0 ] // 2 ]
238- taps [0 ] = taplefttop
239232 # righttopcorner taper
240233 taprighttop = tap .copy ()
241234 taprighttop [:, - nover [1 ] :] = tap [:, nwin [1 ] // 2 ][:, np .newaxis ]
242235 taprighttop [: nover [0 ]] = taprighttop [nwin [0 ] // 2 ]
243- taps [nwins1 - 1 ] = taprighttop
244236 # leftbottomcorner taper
245237 tapleftbottom = tap .copy ()
246238 tapleftbottom [:, : nover [1 ]] = tap [:, nwin [1 ] // 2 ][:, np .newaxis ]
247239 tapleftbottom [- nover [0 ] :] = tapleftbottom [nwin [0 ] // 2 ]
248- taps [nwins - nwins1 ] = tapleftbottom
249240 # rightbottomcorner taper
250241 taprightbottom = tap .copy ()
251242 taprightbottom [:, - nover [1 ] :] = tap [:, nwin [1 ] // 2 ][:, np .newaxis ]
252243 taprightbottom [- nover [0 ] :] = taprightbottom [nwin [0 ] // 2 ]
253- taps [nwins - 1 ] = taprightbottom
254- self .taps = np .vstack (taps ).reshape (nwins0 , nwins1 , nwin [0 ], nwin [1 ])
244+
245+ if self .savetaper :
246+ taps = [
247+ tap ,
248+ ] * nwins
249+ for itap in range (0 , nwins1 ):
250+ taps [itap ] = taptop
251+ for itap in range (nwins - nwins1 , nwins ):
252+ taps [itap ] = tapbottom
253+ for itap in range (0 , nwins , nwins1 ):
254+ taps [itap ] = tapleft
255+ for itap in range (nwins1 - 1 , nwins , nwins1 ):
256+ taps [itap ] = tapright
257+ taps [0 ] = taplefttop
258+ taps [nwins1 - 1 ] = taprighttop
259+ taps [nwins - nwins1 ] = tapleftbottom
260+ taps [nwins - 1 ] = taprightbottom
261+ self .taps = np .vstack (taps ).reshape (nwins0 , nwins1 , nwin [0 ], nwin [1 ])
262+ else :
263+ taps = [
264+ taplefttop ,
265+ taptop ,
266+ taprighttop ,
267+ tapleft ,
268+ tap ,
269+ tapright ,
270+ tapleftbottom ,
271+ tapbottom ,
272+ taprightbottom ,
273+ ]
274+ self .taps = np .vstack (taps ).reshape (3 , 3 , nwin [0 ], nwin [1 ])
255275
256276 # define scalings
257277 if scalings is None :
@@ -273,8 +293,10 @@ def __init__(
273293 name = name ,
274294 )
275295
296+ self ._register_multiplications (self .savetaper )
297+
276298 @reshaped ()
277- def _matvec (self , x : NDArray ) -> NDArray :
299+ def _matvec_savetaper (self , x : NDArray ) -> NDArray :
278300 ncp = get_array_module (x )
279301 if self .tapertype is not None :
280302 self .taps = to_cupy_conditional (x , self .taps )
@@ -299,7 +321,7 @@ def _matvec(self, x: NDArray) -> NDArray:
299321 return y
300322
301323 @reshaped
302- def _rmatvec (self , x : NDArray ) -> NDArray :
324+ def _rmatvec_savetaper (self , x : NDArray ) -> NDArray :
303325 ncp = get_array_module (x )
304326 ncp_sliding_window_view = get_sliding_window_view (x )
305327 if self .tapertype is not None :
@@ -319,3 +341,111 @@ def _rmatvec(self, x: NDArray) -> NDArray:
319341 ywins [iwin0 , iwin1 ].ravel ()
320342 ).reshape (self .dims [2 ], self .dims [3 ])
321343 return y
344+
345+ @reshaped ()
346+ def _matvec_nosavetaper (self , x : NDArray ) -> NDArray :
347+ ncp = get_array_module (x )
348+ if self .tapertype is not None :
349+ self .taps = to_cupy_conditional (x , self .taps )
350+ y = ncp .zeros (self .dimsd , dtype = self .dtype )
351+ if self .simOp :
352+ x = self .Op @ x
353+ for iwin0 in range (self .dims [0 ]):
354+ for iwin1 in range (self .dims [1 ]):
355+ if self .simOp :
356+ xxwin = x [iwin0 , iwin1 ].reshape (self .nwin )
357+ else :
358+ xxwin = self .Op .matvec (x [iwin0 , iwin1 ].ravel ()).reshape (self .nwin )
359+ if self .tapertype is not None :
360+ if iwin0 == 0 and iwin1 == 0 :
361+ xxwin = self .taps [0 , 0 ] * xxwin
362+ elif iwin0 == 0 and iwin1 == self .dims [1 ] - 1 :
363+ xxwin = self .taps [0 , - 1 ] * xxwin
364+ elif iwin0 == 0 :
365+ xxwin = self .taps [0 , 1 ] * xxwin
366+ elif iwin0 == self .dims [0 ] - 1 and iwin1 == 0 :
367+ xxwin = self .taps [- 1 , 0 ] * xxwin
368+ elif iwin0 == self .dims [0 ] - 1 and iwin1 == self .dims [1 ] - 1 :
369+ xxwin = self .taps [- 1 , - 1 ] * xxwin
370+ elif iwin0 == self .dims [0 ] - 1 :
371+ xxwin = self .taps [- 1 , 1 ] * xxwin
372+ elif iwin1 == 0 :
373+ xxwin = self .taps [1 , 0 ] * xxwin
374+ elif iwin1 == self .dims [1 ] - 1 :
375+ xxwin = self .taps [1 , - 1 ] * xxwin
376+ else :
377+ xxwin = self .taps [1 , 1 ] * xxwin
378+
379+ y [
380+ self .dwins_inends [0 ][0 ][iwin0 ] : self .dwins_inends [0 ][1 ][iwin0 ],
381+ self .dwins_inends [1 ][0 ][iwin1 ] : self .dwins_inends [1 ][1 ][iwin1 ],
382+ ] += xxwin
383+ return y
384+
385+ @reshaped
386+ def _rmatvec_nosavetaper (self , x : NDArray ) -> NDArray :
387+ ncp = get_array_module (x )
388+ ncp_sliding_window_view = get_sliding_window_view (x )
389+ if self .tapertype is not None :
390+ self .taps = to_cupy_conditional (x , self .taps )
391+ ywins = ncp_sliding_window_view (x , self .nwin )[
392+ :: self .nwin [0 ] - self .nover [0 ], :: self .nwin [1 ] - self .nover [1 ]
393+ ].copy ()
394+ if self .simOp :
395+ if self .tapertype is not None :
396+ for iwin0 in range (self .dims [0 ]):
397+ for iwin1 in range (self .dims [1 ]):
398+ if iwin0 == 0 and iwin1 == 0 :
399+ ywins [0 , 0 ] = self .taps [0 , 0 ] * ywins [0 , 0 ]
400+ elif iwin0 == 0 and iwin1 == self .dims [1 ] - 1 :
401+ ywins [0 , - 1 ] = self .taps [0 , - 1 ] * ywins [0 , - 1 ]
402+ elif iwin0 == 0 :
403+ ywins [0 , iwin1 ] = self .taps [0 , 1 ] * ywins [0 , iwin1 ]
404+ elif iwin0 == self .dims [0 ] - 1 and iwin1 == 0 :
405+ ywins [- 1 , 0 ] = self .taps [- 1 , 0 ] * ywins [- 1 , 0 ]
406+ elif iwin0 == self .dims [0 ] - 1 and iwin1 == self .dims [1 ] - 1 :
407+ ywins [- 1 , - 1 ] = self .taps [- 1 , - 1 ] * ywins [- 1 , - 1 ]
408+ elif iwin0 == self .dims [0 ] - 1 :
409+ ywins [- 1 , iwin1 ] = self .taps [- 1 , 1 ] * ywins [- 1 , iwin1 ]
410+ elif iwin1 == 0 :
411+ ywins [iwin0 , 0 ] = self .taps [1 , 0 ] * ywins [iwin0 , 0 ]
412+ elif iwin1 == self .dims [1 ] - 1 :
413+ ywins [iwin0 , - 1 ] = self .taps [1 , - 1 ] * ywins [iwin0 , - 1 ]
414+ else :
415+ ywins [iwin0 , iwin1 ] = self .taps [1 , 1 ] * ywins [iwin0 , iwin1 ]
416+ y = self .Op .H @ ywins
417+ else :
418+ y = ncp .zeros (self .dims , dtype = self .dtype )
419+ for iwin0 in range (self .dims [0 ]):
420+ for iwin1 in range (self .dims [1 ]):
421+ if self .tapertype is not None :
422+ if iwin0 == 0 and iwin1 == 0 :
423+ ywins [0 , 0 ] = self .taps [0 , 0 ] * ywins [0 , 0 ]
424+ elif iwin0 == 0 and iwin1 == self .dims [1 ] - 1 :
425+ ywins [0 , - 1 ] = self .taps [0 , - 1 ] * ywins [0 , - 1 ]
426+ elif iwin0 == 0 :
427+ ywins [0 , iwin1 ] = self .taps [0 , 1 ] * ywins [0 , iwin1 ]
428+ elif iwin0 == self .dims [0 ] - 1 and iwin1 == 0 :
429+ ywins [- 1 , 0 ] = self .taps [- 1 , 0 ] * ywins [- 1 , 0 ]
430+ elif iwin0 == self .dims [0 ] - 1 and iwin1 == self .dims [1 ] - 1 :
431+ ywins [- 1 , - 1 ] = self .taps [- 1 , - 1 ] * ywins [- 1 , - 1 ]
432+ elif iwin0 == self .dims [0 ] - 1 :
433+ ywins [- 1 , iwin1 ] = self .taps [- 1 , 1 ] * ywins [- 1 , iwin1 ]
434+ elif iwin1 == 0 :
435+ ywins [iwin0 , 0 ] = self .taps [1 , 0 ] * ywins [iwin0 , 0 ]
436+ elif iwin1 == self .dims [1 ] - 1 :
437+ ywins [iwin0 , - 1 ] = self .taps [1 , - 1 ] * ywins [iwin0 , - 1 ]
438+ else :
439+ ywins [iwin0 , iwin1 ] = self .taps [1 , 1 ] * ywins [iwin0 , iwin1 ]
440+ y [iwin0 , iwin1 ] = self .Op .rmatvec (
441+ ywins [iwin0 , iwin1 ].ravel ()
442+ ).reshape (self .dims [2 ], self .dims [3 ])
443+ return y
444+
445+ def _register_multiplications (self , savetaper : bool ) -> None :
446+ if savetaper :
447+ self ._matvec = self ._matvec_savetaper
448+ self ._rmatvec = self ._rmatvec_savetaper
449+ else :
450+ self ._matvec = self ._matvec_nosavetaper
451+ self ._rmatvec = self ._rmatvec_nosavetaper
0 commit comments