Skip to content

Commit 493e4d4

Browse files
committed
feature: improved Sliding1d
Changed sliding1d from using other operators to being directly implemented (including also an option to apply Op simultaneously to all windows).
1 parent 42c6b2f commit 493e4d4

File tree

2 files changed

+138
-59
lines changed

2 files changed

+138
-59
lines changed

pylops/signalprocessing/sliding1d.py

Lines changed: 114 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,18 @@
66
import logging
77
from typing import Tuple, Union
88

9+
import numpy as np
10+
from numpy.lib.stride_tricks import sliding_window_view
11+
912
from pylops import LinearOperator
10-
from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction
1113
from pylops.signalprocessing.sliding2d import _slidingsteps
1214
from pylops.utils._internal import _value_or_sized_to_tuple
15+
from pylops.utils.backend import (
16+
get_array_module,
17+
get_sliding_window_view,
18+
to_cupy_conditional,
19+
)
20+
from pylops.utils.decorators import reshaped
1321
from pylops.utils.tapers import taper
1422
from pylops.utils.typing import InputDimsLike, NDArray
1523

@@ -77,15 +85,7 @@ def sliding1d_design(
7785
return nwins, dim, mwins_inends, dwins_inends
7886

7987

80-
def Sliding1D(
81-
Op: LinearOperator,
82-
dim: Union[int, InputDimsLike],
83-
dimd: Union[int, InputDimsLike],
84-
nwin: int,
85-
nover: int,
86-
tapertype: str = "hanning",
87-
name: str = "S",
88-
) -> LinearOperator:
88+
class Sliding1D(LinearOperator):
8989
r"""1D Sliding transform operator.
9090
9191
Apply a transform operator ``Op`` repeatedly to slices of the model
@@ -103,6 +103,12 @@ def Sliding1D(
103103
``nover``, it is recommended to first run ``sliding1d_design`` to obtain
104104
the corresponding ``dims`` and number of windows.
105105
106+
.. note:: Two kind of operators ``Op`` can be provided: the first
107+
applies a single transformation to each window separately; the second
108+
applies the transformation to all of the windows at the same time. This
109+
is directly inferred during initialization when the following condition
110+
holds ``Op.shape[1] == dim[0]``.
111+
106112
.. warning:: Depending on the choice of `nwin` and `nover` as well as the
107113
size of the data, sliding windows may not cover the entire data.
108114
The start and end indices of each window will be displayed and returned
@@ -127,62 +133,111 @@ def Sliding1D(
127133
128134
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
129135
130-
Returns
131-
-------
132-
Sop : :obj:`pylops.LinearOperator`
133-
Sliding operator
134-
135136
Raises
136137
------
137138
ValueError
138139
Identified number of windows is not consistent with provided model
139140
shape (``dims``).
140141
141142
"""
142-
dim: Tuple[int, ...] = _value_or_sized_to_tuple(dim)
143-
dimd: Tuple[int, ...] = _value_or_sized_to_tuple(dimd)
144143

145-
# data windows
146-
dwin_ins, dwin_ends = _slidingsteps(dimd[0], nwin, nover)
147-
nwins = len(dwin_ins)
148-
149-
# check windows
150-
if nwins * Op.shape[1] != dim[0]:
151-
raise ValueError(
152-
f"Model shape (dim={dim}) is not consistent with chosen "
153-
f"number of windows. Run sliding1d_design to identify the "
154-
f"correct number of windows for the current "
155-
"model size..."
144+
def __init__(
145+
self,
146+
Op: LinearOperator,
147+
dim: Union[int, InputDimsLike],
148+
dimd: Union[int, InputDimsLike],
149+
nwin: int,
150+
nover: int,
151+
tapertype: str = "hanning",
152+
name: str = "S",
153+
) -> None:
154+
155+
dim: Tuple[int, ...] = _value_or_sized_to_tuple(dim)
156+
dimd: Tuple[int, ...] = _value_or_sized_to_tuple(dimd)
157+
158+
# data windows
159+
dwin_ins, dwin_ends = _slidingsteps(dimd[0], nwin, nover)
160+
self.dwin_inends = (dwin_ins, dwin_ends)
161+
nwins = len(dwin_ins)
162+
self.nwin = nwin
163+
self.nover = nover
164+
165+
# check windows
166+
if nwins * Op.shape[1] != dim[0] and Op.shape[1] != dim[0]:
167+
raise ValueError(
168+
f"Model shape (dim={dim}) is not consistent with chosen "
169+
f"number of windows. Run sliding1d_design to identify the "
170+
f"correct number of windows for the current "
171+
"model size..."
172+
)
173+
174+
# create tapers
175+
self.tapertype = tapertype
176+
if self.tapertype is not None:
177+
tap = taper(nwin, nover, tapertype=self.tapertype)
178+
tapin = tap.copy()
179+
tapin[:nover] = 1
180+
tapend = tap.copy()
181+
tapend[-nover:] = 1
182+
self.taps = [
183+
tapin,
184+
]
185+
for i in range(1, nwins - 1):
186+
self.taps.append(tap)
187+
self.taps.append(tapend)
188+
self.taps = np.vstack(self.taps)
189+
190+
# check if operator is applied to all windows simultaneously
191+
self.simOp = False
192+
if Op.shape[1] == dim[0]:
193+
self.simOp = True
194+
self.Op = Op
195+
196+
# create temporary shape and strides for cpy
197+
self.shape_wins = None
198+
self.strides_wins = None
199+
200+
super().__init__(
201+
dtype=Op.dtype,
202+
dims=(nwins, int(dim[0] // nwins)),
203+
dimsd=dimd,
204+
clinear=False,
205+
name=name,
156206
)
157207

158-
# create tapers
159-
if tapertype is not None:
160-
tap = taper(nwin, nover, tapertype=tapertype).astype(Op.dtype)
161-
tapin = tap.copy()
162-
tapin[:nover] = 1
163-
tapend = tap.copy()
164-
tapend[-nover:] = 1
165-
taps = {}
166-
taps[0] = tapin
167-
for i in range(1, nwins - 1):
168-
taps[i] = tap
169-
taps[nwins - 1] = tapend
170-
171-
# transform to apply
172-
if tapertype is None:
173-
OOp = BlockDiag([Op for _ in range(nwins)])
174-
else:
175-
OOp = BlockDiag(
176-
[Diagonal(taps[itap].ravel(), dtype=Op.dtype) * Op for itap in range(nwins)]
177-
)
178-
179-
combining = HStack(
180-
[
181-
Restriction(dimd, range(win_in, win_end), dtype=Op.dtype).H
182-
for win_in, win_end in zip(dwin_ins, dwin_ends)
183-
]
184-
)
185-
Sop = LinearOperator(combining * OOp)
186-
Sop.dims, Sop.dimsd = (nwins, int(dim[0] // nwins)), dimd
187-
Sop.name = name
188-
return Sop
208+
@reshaped
209+
def _matvec(self, x: NDArray) -> NDArray:
210+
ncp = get_array_module(x)
211+
if self.tapertype is not None:
212+
self.taps = to_cupy_conditional(x, self.taps)
213+
y = ncp.zeros(self.dimsd, dtype=self.dtype)
214+
if self.simOp:
215+
x = self.Op @ x
216+
for iwin0 in range(self.dims[0]):
217+
if self.simOp:
218+
xx = x[iwin0]
219+
else:
220+
xx = self.Op.matvec(x[iwin0])
221+
if self.tapertype is not None:
222+
xxwin = self.taps[iwin0] * xx
223+
else:
224+
xxwin = xx
225+
y[self.dwin_inends[0][iwin0] : self.dwin_inends[1][iwin0]] += xxwin
226+
return y
227+
228+
@reshaped
229+
def _rmatvec(self, x: NDArray) -> NDArray:
230+
ncp = get_array_module(x)
231+
ncp_sliding_window_view = get_sliding_window_view(x)
232+
if self.tapertype is not None:
233+
self.taps = to_cupy_conditional(x, self.taps)
234+
ywins = ncp_sliding_window_view(x, self.nwin)[:: self.nwin - self.nover]
235+
if self.tapertype is not None:
236+
ywins = ywins * self.taps
237+
if self.simOp:
238+
y = self.Op.H @ ywins
239+
else:
240+
y = ncp.zeros(self.dims, dtype=self.dtype)
241+
for iwin0 in range(self.dims[0]):
242+
y[iwin0] = self.Op.rmatvec(ywins[iwin0])
243+
return y

pylops/utils/backend.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"get_oaconvolve",
88
"get_correlate",
99
"get_add_at",
10+
"get_sliding_window_view",
1011
"get_block_diag",
1112
"get_toeplitz",
1213
"get_csc_matrix",
@@ -228,6 +229,29 @@ def get_add_at(x: npt.ArrayLike) -> Callable:
228229
return cupyx.scatter_add
229230

230231

232+
def get_sliding_window_view(x: npt.ArrayLike) -> Callable:
233+
"""Returns correct sliding_window_view module based on input
234+
235+
Parameters
236+
----------
237+
x : :obj:`numpy.ndarray`
238+
Array
239+
240+
Returns
241+
-------
242+
mod : :obj:`func`
243+
Module to be used to process array (:mod:`numpy` or :mod:`cupy`)
244+
245+
"""
246+
if not deps.cupy_enabled:
247+
return np.lib.stride_tricks.sliding_window_view
248+
249+
if cp.get_array_module(x) == np:
250+
return np.lib.stride_tricks.sliding_window_view
251+
else:
252+
return cp.lib.stride_tricks.sliding_window_view
253+
254+
231255
def get_block_diag(x: npt.ArrayLike) -> Callable:
232256
"""Returns correct block_diag module based on input
233257

0 commit comments

Comments
 (0)