@@ -297,6 +297,27 @@ def __init__(
297297
298298 self ._register_multiplications (self .savetaper )
299299
300+ def _apply_taper (self , ywins , iwin0 , iwin1 ):
301+ if iwin0 == 0 and iwin1 == 0 :
302+ ywins [0 , 0 ] = self .taps [0 , 0 ] * ywins [0 , 0 ]
303+ elif iwin0 == 0 and iwin1 == self .dims [1 ] - 1 :
304+ ywins [0 , - 1 ] = self .taps [0 , - 1 ] * ywins [0 , - 1 ]
305+ elif iwin0 == 0 :
306+ ywins [0 , iwin1 ] = self .taps [0 , 1 ] * ywins [0 , iwin1 ]
307+ elif iwin0 == self .dims [0 ] - 1 and iwin1 == 0 :
308+ ywins [- 1 , 0 ] = self .taps [- 1 , 0 ] * ywins [- 1 , 0 ]
309+ elif iwin0 == self .dims [0 ] - 1 and iwin1 == self .dims [1 ] - 1 :
310+ ywins [- 1 , - 1 ] = self .taps [- 1 , - 1 ] * ywins [- 1 , - 1 ]
311+ elif iwin0 == self .dims [0 ] - 1 :
312+ ywins [- 1 , iwin1 ] = self .taps [- 1 , 1 ] * ywins [- 1 , iwin1 ]
313+ elif iwin1 == 0 :
314+ ywins [iwin0 , 0 ] = self .taps [1 , 0 ] * ywins [iwin0 , 0 ]
315+ elif iwin1 == self .dims [1 ] - 1 :
316+ ywins [iwin0 , - 1 ] = self .taps [1 , - 1 ] * ywins [iwin0 , - 1 ]
317+ else :
318+ ywins [iwin0 , iwin1 ] = self .taps [1 , 1 ] * ywins [iwin0 , iwin1 ]
319+ return ywins
320+
300321 @reshaped ()
301322 def _matvec_savetaper (self , x : NDArray ) -> NDArray :
302323 ncp = get_array_module (x )
@@ -397,48 +418,14 @@ def _rmatvec_nosavetaper(self, x: NDArray) -> NDArray:
397418 if self .tapertype is not None :
398419 for iwin0 in range (self .dims [0 ]):
399420 for iwin1 in range (self .dims [1 ]):
400- if iwin0 == 0 and iwin1 == 0 :
401- ywins [0 , 0 ] = self .taps [0 , 0 ] * ywins [0 , 0 ]
402- elif iwin0 == 0 and iwin1 == self .dims [1 ] - 1 :
403- ywins [0 , - 1 ] = self .taps [0 , - 1 ] * ywins [0 , - 1 ]
404- elif iwin0 == 0 :
405- ywins [0 , iwin1 ] = self .taps [0 , 1 ] * ywins [0 , iwin1 ]
406- elif iwin0 == self .dims [0 ] - 1 and iwin1 == 0 :
407- ywins [- 1 , 0 ] = self .taps [- 1 , 0 ] * ywins [- 1 , 0 ]
408- elif iwin0 == self .dims [0 ] - 1 and iwin1 == self .dims [1 ] - 1 :
409- ywins [- 1 , - 1 ] = self .taps [- 1 , - 1 ] * ywins [- 1 , - 1 ]
410- elif iwin0 == self .dims [0 ] - 1 :
411- ywins [- 1 , iwin1 ] = self .taps [- 1 , 1 ] * ywins [- 1 , iwin1 ]
412- elif iwin1 == 0 :
413- ywins [iwin0 , 0 ] = self .taps [1 , 0 ] * ywins [iwin0 , 0 ]
414- elif iwin1 == self .dims [1 ] - 1 :
415- ywins [iwin0 , - 1 ] = self .taps [1 , - 1 ] * ywins [iwin0 , - 1 ]
416- else :
417- ywins [iwin0 , iwin1 ] = self .taps [1 , 1 ] * ywins [iwin0 , iwin1 ]
421+ ywins = self ._apply_taper (ywins , iwin0 , iwin1 )
418422 y = self .Op .H @ ywins
419423 else :
420424 y = ncp .zeros (self .dims , dtype = self .dtype )
421425 for iwin0 in range (self .dims [0 ]):
422426 for iwin1 in range (self .dims [1 ]):
423427 if self .tapertype is not None :
424- if iwin0 == 0 and iwin1 == 0 :
425- ywins [0 , 0 ] = self .taps [0 , 0 ] * ywins [0 , 0 ]
426- elif iwin0 == 0 and iwin1 == self .dims [1 ] - 1 :
427- ywins [0 , - 1 ] = self .taps [0 , - 1 ] * ywins [0 , - 1 ]
428- elif iwin0 == 0 :
429- ywins [0 , iwin1 ] = self .taps [0 , 1 ] * ywins [0 , iwin1 ]
430- elif iwin0 == self .dims [0 ] - 1 and iwin1 == 0 :
431- ywins [- 1 , 0 ] = self .taps [- 1 , 0 ] * ywins [- 1 , 0 ]
432- elif iwin0 == self .dims [0 ] - 1 and iwin1 == self .dims [1 ] - 1 :
433- ywins [- 1 , - 1 ] = self .taps [- 1 , - 1 ] * ywins [- 1 , - 1 ]
434- elif iwin0 == self .dims [0 ] - 1 :
435- ywins [- 1 , iwin1 ] = self .taps [- 1 , 1 ] * ywins [- 1 , iwin1 ]
436- elif iwin1 == 0 :
437- ywins [iwin0 , 0 ] = self .taps [1 , 0 ] * ywins [iwin0 , 0 ]
438- elif iwin1 == self .dims [1 ] - 1 :
439- ywins [iwin0 , - 1 ] = self .taps [1 , - 1 ] * ywins [iwin0 , - 1 ]
440- else :
441- ywins [iwin0 , iwin1 ] = self .taps [1 , 1 ] * ywins [iwin0 , iwin1 ]
428+ ywins = self ._apply_taper (ywins , iwin0 , iwin1 )
442429 y [iwin0 , iwin1 ] = self .Op .rmatvec (
443430 ywins [iwin0 , iwin1 ].ravel ()
444431 ).reshape (self .dims [2 ], self .dims [3 ])
0 commit comments