Skip to content

Commit 9cabdf4

Browse files
committed
feature: improved Patch2D
Changed patch2d from using other operators to being directly implemented (including also an option to apply Op simultaneously to all windows).
1 parent 2df59f0 commit 9cabdf4

File tree

2 files changed

+155
-109
lines changed

2 files changed

+155
-109
lines changed

pylops/signalprocessing/patch2d.py

Lines changed: 154 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,14 @@
99
import numpy as np
1010

1111
from pylops import LinearOperator
12-
from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
1312
from pylops.signalprocessing.sliding2d import _slidingsteps
13+
from pylops.utils._internal import _value_or_sized_to_tuple
14+
from pylops.utils.backend import (
15+
get_array_module,
16+
get_sliding_window_view,
17+
to_cupy_conditional,
18+
)
19+
from pylops.utils.decorators import reshaped
1420
from pylops.utils.tapers import taper2d
1521
from pylops.utils.typing import InputDimsLike, NDArray
1622

@@ -91,17 +97,7 @@ def patch2d_design(
9197
return nwins, dims, mwins_inends, dwins_inends
9298

9399

94-
def Patch2D(
95-
Op: LinearOperator,
96-
dims: InputDimsLike,
97-
dimsd: InputDimsLike,
98-
nwin: Tuple[int, int],
99-
nover: Tuple[int, int],
100-
nop: Tuple[int, int],
101-
tapertype: str = "hanning",
102-
scalings: Optional[Sequence[float]] = None,
103-
name: str = "P",
104-
) -> LinearOperator:
100+
class Patch2D(LinearOperator):
105101
"""2D Patch transform operator.
106102
107103
Apply a transform operator ``Op`` repeatedly to patches of the model
@@ -172,104 +168,154 @@ def Patch2D(
172168
Patch3D: 3D Patching transform operator.
173169
174170
"""
175-
# data windows
176-
dwin0_ins, dwin0_ends = _slidingsteps(dimsd[0], nwin[0], nover[0])
177-
dwin1_ins, dwin1_ends = _slidingsteps(dimsd[1], nwin[1], nover[1])
178-
nwins0 = len(dwin0_ins)
179-
nwins1 = len(dwin1_ins)
180-
nwins = nwins0 * nwins1
181-
182-
# check patching
183-
if nwins0 * nop[0] != dims[0] or nwins1 * nop[1] != dims[1]:
184-
raise ValueError(
185-
f"Model shape (dims={dims}) is not consistent with chosen "
186-
f"number of windows. Run patch2d_design to identify the "
187-
f"correct number of windows for the current "
188-
"model size..."
189-
)
190171

191-
# create tapers
192-
if tapertype is not None:
193-
tap = taper2d(nwin[1], nwin[0], nover, tapertype=tapertype).astype(Op.dtype)
194-
taps = {itap: tap for itap in range(nwins)}
195-
# topmost tapers
196-
taptop = tap.copy()
197-
taptop[: nover[0]] = tap[nwin[0] // 2]
198-
for itap in range(0, nwins1):
199-
taps[itap] = taptop
200-
# bottommost tapers
201-
tapbottom = tap.copy()
202-
tapbottom[-nover[0] :] = tap[nwin[0] // 2]
203-
for itap in range(nwins - nwins1, nwins):
204-
taps[itap] = tapbottom
205-
# leftmost tapers
206-
tapleft = tap.copy()
207-
tapleft[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
208-
for itap in range(0, nwins, nwins1):
209-
taps[itap] = tapleft
210-
# rightmost tapers
211-
tapright = tap.copy()
212-
tapright[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
213-
for itap in range(nwins1 - 1, nwins, nwins1):
214-
taps[itap] = tapright
215-
# lefttopcorner taper
216-
taplefttop = tap.copy()
217-
taplefttop[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
218-
taplefttop[: nover[0]] = taplefttop[nwin[0] // 2]
219-
taps[0] = taplefttop
220-
# righttopcorner taper
221-
taprighttop = tap.copy()
222-
taprighttop[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
223-
taprighttop[: nover[0]] = taprighttop[nwin[0] // 2]
224-
taps[nwins1 - 1] = taprighttop
225-
# leftbottomcorner taper
226-
tapleftbottom = tap.copy()
227-
tapleftbottom[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
228-
tapleftbottom[-nover[0] :] = tapleftbottom[nwin[0] // 2]
229-
taps[nwins - nwins1] = tapleftbottom
230-
# rightbottomcorner taper
231-
taprightbottom = tap.copy()
232-
taprightbottom[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
233-
taprightbottom[-nover[0] :] = taprightbottom[nwin[0] // 2]
234-
taps[nwins - 1] = taprightbottom
235-
236-
# define scalings
237-
if scalings is None:
238-
scalings = [1.0] * nwins
239-
240-
# transform to apply
241-
if tapertype is None:
242-
OOp = BlockDiag([scalings[itap] * Op for itap in range(nwins)])
243-
else:
244-
OOp = BlockDiag(
245-
[
246-
scalings[itap] * Diagonal(taps[itap].ravel(), dtype=Op.dtype) * Op
247-
for itap in range(nwins)
248-
]
172+
def __init__(
173+
self,
174+
Op: LinearOperator,
175+
dims: InputDimsLike,
176+
dimsd: InputDimsLike,
177+
nwin: Tuple[int, int],
178+
nover: Tuple[int, int],
179+
nop: Tuple[int, int],
180+
tapertype: str = "hanning",
181+
scalings: Optional[Sequence[float]] = None,
182+
name: str = "P",
183+
) -> None:
184+
185+
dims: Tuple[int, ...] = _value_or_sized_to_tuple(dims)
186+
dimsd: Tuple[int, ...] = _value_or_sized_to_tuple(dimsd)
187+
188+
# data windows
189+
dwin0_ins, dwin0_ends = _slidingsteps(dimsd[0], nwin[0], nover[0])
190+
dwin1_ins, dwin1_ends = _slidingsteps(dimsd[1], nwin[1], nover[1])
191+
self.dwins_inends = ((dwin0_ins, dwin0_ends), (dwin1_ins, dwin1_ends))
192+
nwins0 = len(dwin0_ins)
193+
nwins1 = len(dwin1_ins)
194+
nwins = nwins0 * nwins1
195+
self.nwin = nwin
196+
self.nover = nover
197+
198+
# check patching
199+
if nwins0 * nop[0] != dims[0] or nwins1 * nop[1] != dims[1]:
200+
raise ValueError(
201+
f"Model shape (dims={dims}) is not consistent with chosen "
202+
f"number of windows. Run patch2d_design to identify the "
203+
f"correct number of windows for the current "
204+
"model size..."
205+
)
206+
207+
# create tapers
208+
self.tapertype = tapertype
209+
if self.tapertype is not None:
210+
tap = taper2d(nwin[1], nwin[0], nover, tapertype=tapertype).astype(Op.dtype)
211+
taps = [
212+
tap,
213+
] * nwins
214+
# topmost tapers
215+
taptop = tap.copy()
216+
taptop[: nover[0]] = tap[nwin[0] // 2]
217+
for itap in range(0, nwins1):
218+
taps[itap] = taptop
219+
# bottommost tapers
220+
tapbottom = tap.copy()
221+
tapbottom[-nover[0] :] = tap[nwin[0] // 2]
222+
for itap in range(nwins - nwins1, nwins):
223+
taps[itap] = tapbottom
224+
# leftmost tapers
225+
tapleft = tap.copy()
226+
tapleft[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
227+
for itap in range(0, nwins, nwins1):
228+
taps[itap] = tapleft
229+
# rightmost tapers
230+
tapright = tap.copy()
231+
tapright[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
232+
for itap in range(nwins1 - 1, nwins, nwins1):
233+
taps[itap] = tapright
234+
# lefttopcorner taper
235+
taplefttop = tap.copy()
236+
taplefttop[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
237+
taplefttop[: nover[0]] = taplefttop[nwin[0] // 2]
238+
taps[0] = taplefttop
239+
# righttopcorner taper
240+
taprighttop = tap.copy()
241+
taprighttop[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
242+
taprighttop[: nover[0]] = taprighttop[nwin[0] // 2]
243+
taps[nwins1 - 1] = taprighttop
244+
# leftbottomcorner taper
245+
tapleftbottom = tap.copy()
246+
tapleftbottom[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis]
247+
tapleftbottom[-nover[0] :] = tapleftbottom[nwin[0] // 2]
248+
taps[nwins - nwins1] = tapleftbottom
249+
# rightbottomcorner taper
250+
taprightbottom = tap.copy()
251+
taprightbottom[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis]
252+
taprightbottom[-nover[0] :] = taprightbottom[nwin[0] // 2]
253+
taps[nwins - 1] = taprightbottom
254+
self.taps = np.vstack(taps).reshape(nwins0, nwins1, nwin[0], nwin[1])
255+
256+
# define scalings
257+
if scalings is None:
258+
self.scalings = [1.0] * nwins
259+
else:
260+
self.scalings = scalings
261+
262+
# check if operator is applied to all windows simultaneously
263+
self.simOp = False
264+
if Op.shape[1] == np.prod(dims):
265+
self.simOp = True
266+
self.Op = Op
267+
268+
super().__init__(
269+
dtype=Op.dtype,
270+
dims=(nwins0, nwins1, int(dims[0] // nwins0), int(dims[1] // nwins1)),
271+
dimsd=dimsd,
272+
clinear=False,
273+
name=name,
249274
)
250275

251-
hstack = HStack(
252-
[
253-
Restriction(
254-
(nwin[0], dimsd[1]), range(win_in, win_end), axis=1, dtype=Op.dtype
255-
).H
256-
for win_in, win_end in zip(dwin1_ins, dwin1_ends)
257-
]
258-
)
259-
combining1 = BlockDiag([hstack] * nwins0)
276+
@reshaped()
277+
def _matvec(self, x: NDArray) -> NDArray:
278+
ncp = get_array_module(x)
279+
if self.tapertype is not None:
280+
self.taps = to_cupy_conditional(x, self.taps)
281+
y = ncp.zeros(self.dimsd, dtype=self.dtype)
282+
if self.simOp:
283+
x = self.Op @ x
284+
for iwin0 in range(self.dims[0]):
285+
for iwin1 in range(self.dims[1]):
286+
if self.simOp:
287+
xx = x[iwin0, iwin1].reshape(self.nwin)
288+
else:
289+
xx = self.Op.matvec(x[iwin0, iwin1].ravel()).reshape(self.nwin)
290+
if self.tapertype is not None:
291+
xxwin = self.taps[iwin0, iwin1] * xx
292+
else:
293+
xxwin = xx
260294

261-
combining0 = HStack(
262-
[
263-
Restriction(dimsd, range(win_in, win_end), axis=0, dtype=Op.dtype).H
264-
for win_in, win_end in zip(dwin0_ins, dwin0_ends)
295+
y[
296+
self.dwins_inends[0][0][iwin0] : self.dwins_inends[0][1][iwin0],
297+
self.dwins_inends[1][0][iwin1] : self.dwins_inends[1][1][iwin1],
298+
] += xxwin
299+
return y
300+
301+
@reshaped
302+
def _rmatvec(self, x: NDArray) -> NDArray:
303+
ncp = get_array_module(x)
304+
ncp_sliding_window_view = get_sliding_window_view(x)
305+
if self.tapertype is not None:
306+
self.taps = to_cupy_conditional(x, self.taps)
307+
ywins = ncp_sliding_window_view(x, self.nwin)[
308+
:: self.nwin[0] - self.nover[0], :: self.nwin[1] - self.nover[1]
265309
]
266-
)
267-
Pop = LinearOperator(combining0 * combining1 * OOp)
268-
Pop.dims, Pop.dimsd = (
269-
nwins0,
270-
nwins1,
271-
int(dims[0] // nwins0),
272-
int(dims[1] // nwins1),
273-
), dimsd
274-
Pop.name = name
275-
return Pop
310+
if self.tapertype is not None:
311+
ywins = ywins * self.taps
312+
if self.simOp:
313+
y = self.Op.H @ ywins
314+
else:
315+
y = ncp.zeros(self.dims, dtype=self.dtype)
316+
for iwin0 in range(self.dims[0]):
317+
for iwin1 in range(self.dims[1]):
318+
y[iwin0, iwin1] = self.Op.rmatvec(
319+
ywins[iwin0, iwin1].ravel()
320+
).reshape(self.dims[2], self.dims[3])
321+
return y

pylops/signalprocessing/sliding3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def __init__(
218218
tap = taper3d(dimsd[2], nwin, nover, tapertype=tapertype).astype(Op.dtype)
219219
taps = [
220220
tap,
221-
] * nwins # {itap: tap for itap in range(nwins)}
221+
] * nwins
222222

223223
# topmost tapers
224224
taptop = tap.copy()

0 commit comments

Comments
 (0)