|
9 | 9 | import numpy as np |
10 | 10 |
|
11 | 11 | from pylops import LinearOperator |
12 | | -from pylops.basicoperators import BlockDiag, Diagonal, HStack, Restriction |
13 | 12 | from pylops.signalprocessing.sliding2d import _slidingsteps |
| 13 | +from pylops.utils._internal import _value_or_sized_to_tuple |
| 14 | +from pylops.utils.backend import ( |
| 15 | + get_array_module, |
| 16 | + get_sliding_window_view, |
| 17 | + to_cupy_conditional, |
| 18 | +) |
| 19 | +from pylops.utils.decorators import reshaped |
14 | 20 | from pylops.utils.tapers import taper2d |
15 | 21 | from pylops.utils.typing import InputDimsLike, NDArray |
16 | 22 |
|
@@ -91,17 +97,7 @@ def patch2d_design( |
91 | 97 | return nwins, dims, mwins_inends, dwins_inends |
92 | 98 |
|
93 | 99 |
|
94 | | -def Patch2D( |
95 | | - Op: LinearOperator, |
96 | | - dims: InputDimsLike, |
97 | | - dimsd: InputDimsLike, |
98 | | - nwin: Tuple[int, int], |
99 | | - nover: Tuple[int, int], |
100 | | - nop: Tuple[int, int], |
101 | | - tapertype: str = "hanning", |
102 | | - scalings: Optional[Sequence[float]] = None, |
103 | | - name: str = "P", |
104 | | -) -> LinearOperator: |
| 100 | +class Patch2D(LinearOperator): |
105 | 101 | """2D Patch transform operator. |
106 | 102 |
|
107 | 103 | Apply a transform operator ``Op`` repeatedly to patches of the model |
@@ -172,104 +168,154 @@ def Patch2D( |
172 | 168 | Patch3D: 3D Patching transform operator. |
173 | 169 |
|
174 | 170 | """ |
175 | | - # data windows |
176 | | - dwin0_ins, dwin0_ends = _slidingsteps(dimsd[0], nwin[0], nover[0]) |
177 | | - dwin1_ins, dwin1_ends = _slidingsteps(dimsd[1], nwin[1], nover[1]) |
178 | | - nwins0 = len(dwin0_ins) |
179 | | - nwins1 = len(dwin1_ins) |
180 | | - nwins = nwins0 * nwins1 |
181 | | - |
182 | | - # check patching |
183 | | - if nwins0 * nop[0] != dims[0] or nwins1 * nop[1] != dims[1]: |
184 | | - raise ValueError( |
185 | | - f"Model shape (dims={dims}) is not consistent with chosen " |
186 | | - f"number of windows. Run patch2d_design to identify the " |
187 | | - f"correct number of windows for the current " |
188 | | - "model size..." |
189 | | - ) |
190 | 171 |
|
191 | | - # create tapers |
192 | | - if tapertype is not None: |
193 | | - tap = taper2d(nwin[1], nwin[0], nover, tapertype=tapertype).astype(Op.dtype) |
194 | | - taps = {itap: tap for itap in range(nwins)} |
195 | | - # topmost tapers |
196 | | - taptop = tap.copy() |
197 | | - taptop[: nover[0]] = tap[nwin[0] // 2] |
198 | | - for itap in range(0, nwins1): |
199 | | - taps[itap] = taptop |
200 | | - # bottommost tapers |
201 | | - tapbottom = tap.copy() |
202 | | - tapbottom[-nover[0] :] = tap[nwin[0] // 2] |
203 | | - for itap in range(nwins - nwins1, nwins): |
204 | | - taps[itap] = tapbottom |
205 | | - # leftmost tapers |
206 | | - tapleft = tap.copy() |
207 | | - tapleft[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis] |
208 | | - for itap in range(0, nwins, nwins1): |
209 | | - taps[itap] = tapleft |
210 | | - # rightmost tapers |
211 | | - tapright = tap.copy() |
212 | | - tapright[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis] |
213 | | - for itap in range(nwins1 - 1, nwins, nwins1): |
214 | | - taps[itap] = tapright |
215 | | - # lefttopcorner taper |
216 | | - taplefttop = tap.copy() |
217 | | - taplefttop[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis] |
218 | | - taplefttop[: nover[0]] = taplefttop[nwin[0] // 2] |
219 | | - taps[0] = taplefttop |
220 | | - # righttopcorner taper |
221 | | - taprighttop = tap.copy() |
222 | | - taprighttop[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis] |
223 | | - taprighttop[: nover[0]] = taprighttop[nwin[0] // 2] |
224 | | - taps[nwins1 - 1] = taprighttop |
225 | | - # leftbottomcorner taper |
226 | | - tapleftbottom = tap.copy() |
227 | | - tapleftbottom[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis] |
228 | | - tapleftbottom[-nover[0] :] = tapleftbottom[nwin[0] // 2] |
229 | | - taps[nwins - nwins1] = tapleftbottom |
230 | | - # rightbottomcorner taper |
231 | | - taprightbottom = tap.copy() |
232 | | - taprightbottom[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis] |
233 | | - taprightbottom[-nover[0] :] = taprightbottom[nwin[0] // 2] |
234 | | - taps[nwins - 1] = taprightbottom |
235 | | - |
236 | | - # define scalings |
237 | | - if scalings is None: |
238 | | - scalings = [1.0] * nwins |
239 | | - |
240 | | - # transform to apply |
241 | | - if tapertype is None: |
242 | | - OOp = BlockDiag([scalings[itap] * Op for itap in range(nwins)]) |
243 | | - else: |
244 | | - OOp = BlockDiag( |
245 | | - [ |
246 | | - scalings[itap] * Diagonal(taps[itap].ravel(), dtype=Op.dtype) * Op |
247 | | - for itap in range(nwins) |
248 | | - ] |
| 172 | + def __init__( |
| 173 | + self, |
| 174 | + Op: LinearOperator, |
| 175 | + dims: InputDimsLike, |
| 176 | + dimsd: InputDimsLike, |
| 177 | + nwin: Tuple[int, int], |
| 178 | + nover: Tuple[int, int], |
| 179 | + nop: Tuple[int, int], |
| 180 | + tapertype: str = "hanning", |
| 181 | + scalings: Optional[Sequence[float]] = None, |
| 182 | + name: str = "P", |
| 183 | + ) -> None: |
| 184 | + |
| 185 | + dims: Tuple[int, ...] = _value_or_sized_to_tuple(dims) |
| 186 | + dimsd: Tuple[int, ...] = _value_or_sized_to_tuple(dimsd) |
| 187 | + |
| 188 | + # data windows |
| 189 | + dwin0_ins, dwin0_ends = _slidingsteps(dimsd[0], nwin[0], nover[0]) |
| 190 | + dwin1_ins, dwin1_ends = _slidingsteps(dimsd[1], nwin[1], nover[1]) |
| 191 | + self.dwins_inends = ((dwin0_ins, dwin0_ends), (dwin1_ins, dwin1_ends)) |
| 192 | + nwins0 = len(dwin0_ins) |
| 193 | + nwins1 = len(dwin1_ins) |
| 194 | + nwins = nwins0 * nwins1 |
| 195 | + self.nwin = nwin |
| 196 | + self.nover = nover |
| 197 | + |
| 198 | + # check patching |
| 199 | + if nwins0 * nop[0] != dims[0] or nwins1 * nop[1] != dims[1]: |
| 200 | + raise ValueError( |
| 201 | + f"Model shape (dims={dims}) is not consistent with chosen " |
| 202 | + f"number of windows. Run patch2d_design to identify the " |
| 203 | + f"correct number of windows for the current " |
| 204 | + "model size..." |
| 205 | + ) |
| 206 | + |
| 207 | + # create tapers |
| 208 | + self.tapertype = tapertype |
| 209 | + if self.tapertype is not None: |
| 210 | + tap = taper2d(nwin[1], nwin[0], nover, tapertype=tapertype).astype(Op.dtype) |
| 211 | + taps = [ |
| 212 | + tap, |
| 213 | + ] * nwins |
| 214 | + # topmost tapers |
| 215 | + taptop = tap.copy() |
| 216 | + taptop[: nover[0]] = tap[nwin[0] // 2] |
| 217 | + for itap in range(0, nwins1): |
| 218 | + taps[itap] = taptop |
| 219 | + # bottommost tapers |
| 220 | + tapbottom = tap.copy() |
| 221 | + tapbottom[-nover[0] :] = tap[nwin[0] // 2] |
| 222 | + for itap in range(nwins - nwins1, nwins): |
| 223 | + taps[itap] = tapbottom |
| 224 | + # leftmost tapers |
| 225 | + tapleft = tap.copy() |
| 226 | + tapleft[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis] |
| 227 | + for itap in range(0, nwins, nwins1): |
| 228 | + taps[itap] = tapleft |
| 229 | + # rightmost tapers |
| 230 | + tapright = tap.copy() |
| 231 | + tapright[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis] |
| 232 | + for itap in range(nwins1 - 1, nwins, nwins1): |
| 233 | + taps[itap] = tapright |
| 234 | + # lefttopcorner taper |
| 235 | + taplefttop = tap.copy() |
| 236 | + taplefttop[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis] |
| 237 | + taplefttop[: nover[0]] = taplefttop[nwin[0] // 2] |
| 238 | + taps[0] = taplefttop |
| 239 | + # righttopcorner taper |
| 240 | + taprighttop = tap.copy() |
| 241 | + taprighttop[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis] |
| 242 | + taprighttop[: nover[0]] = taprighttop[nwin[0] // 2] |
| 243 | + taps[nwins1 - 1] = taprighttop |
| 244 | + # leftbottomcorner taper |
| 245 | + tapleftbottom = tap.copy() |
| 246 | + tapleftbottom[:, : nover[1]] = tap[:, nwin[1] // 2][:, np.newaxis] |
| 247 | + tapleftbottom[-nover[0] :] = tapleftbottom[nwin[0] // 2] |
| 248 | + taps[nwins - nwins1] = tapleftbottom |
| 249 | + # rightbottomcorner taper |
| 250 | + taprightbottom = tap.copy() |
| 251 | + taprightbottom[:, -nover[1] :] = tap[:, nwin[1] // 2][:, np.newaxis] |
| 252 | + taprightbottom[-nover[0] :] = taprightbottom[nwin[0] // 2] |
| 253 | + taps[nwins - 1] = taprightbottom |
| 254 | + self.taps = np.vstack(taps).reshape(nwins0, nwins1, nwin[0], nwin[1]) |
| 255 | + |
| 256 | + # define scalings |
| 257 | + if scalings is None: |
| 258 | + self.scalings = [1.0] * nwins |
| 259 | + else: |
| 260 | + self.scalings = scalings |
| 261 | + |
| 262 | + # check if operator is applied to all windows simultaneously |
| 263 | + self.simOp = False |
| 264 | + if Op.shape[1] == np.prod(dims): |
| 265 | + self.simOp = True |
| 266 | + self.Op = Op |
| 267 | + |
| 268 | + super().__init__( |
| 269 | + dtype=Op.dtype, |
| 270 | + dims=(nwins0, nwins1, int(dims[0] // nwins0), int(dims[1] // nwins1)), |
| 271 | + dimsd=dimsd, |
| 272 | + clinear=False, |
| 273 | + name=name, |
249 | 274 | ) |
250 | 275 |
|
251 | | - hstack = HStack( |
252 | | - [ |
253 | | - Restriction( |
254 | | - (nwin[0], dimsd[1]), range(win_in, win_end), axis=1, dtype=Op.dtype |
255 | | - ).H |
256 | | - for win_in, win_end in zip(dwin1_ins, dwin1_ends) |
257 | | - ] |
258 | | - ) |
259 | | - combining1 = BlockDiag([hstack] * nwins0) |
| 276 | + @reshaped() |
| 277 | + def _matvec(self, x: NDArray) -> NDArray: |
| 278 | + ncp = get_array_module(x) |
| 279 | + if self.tapertype is not None: |
| 280 | + self.taps = to_cupy_conditional(x, self.taps) |
| 281 | + y = ncp.zeros(self.dimsd, dtype=self.dtype) |
| 282 | + if self.simOp: |
| 283 | + x = self.Op @ x |
| 284 | + for iwin0 in range(self.dims[0]): |
| 285 | + for iwin1 in range(self.dims[1]): |
| 286 | + if self.simOp: |
| 287 | + xx = x[iwin0, iwin1].reshape(self.nwin) |
| 288 | + else: |
| 289 | + xx = self.Op.matvec(x[iwin0, iwin1].ravel()).reshape(self.nwin) |
| 290 | + if self.tapertype is not None: |
| 291 | + xxwin = self.taps[iwin0, iwin1] * xx |
| 292 | + else: |
| 293 | + xxwin = xx |
260 | 294 |
|
261 | | - combining0 = HStack( |
262 | | - [ |
263 | | - Restriction(dimsd, range(win_in, win_end), axis=0, dtype=Op.dtype).H |
264 | | - for win_in, win_end in zip(dwin0_ins, dwin0_ends) |
| 295 | + y[ |
| 296 | + self.dwins_inends[0][0][iwin0] : self.dwins_inends[0][1][iwin0], |
| 297 | + self.dwins_inends[1][0][iwin1] : self.dwins_inends[1][1][iwin1], |
| 298 | + ] += xxwin |
| 299 | + return y |
| 300 | + |
| 301 | + @reshaped |
| 302 | + def _rmatvec(self, x: NDArray) -> NDArray: |
| 303 | + ncp = get_array_module(x) |
| 304 | + ncp_sliding_window_view = get_sliding_window_view(x) |
| 305 | + if self.tapertype is not None: |
| 306 | + self.taps = to_cupy_conditional(x, self.taps) |
| 307 | + ywins = ncp_sliding_window_view(x, self.nwin)[ |
| 308 | + :: self.nwin[0] - self.nover[0], :: self.nwin[1] - self.nover[1] |
265 | 309 | ] |
266 | | - ) |
267 | | - Pop = LinearOperator(combining0 * combining1 * OOp) |
268 | | - Pop.dims, Pop.dimsd = ( |
269 | | - nwins0, |
270 | | - nwins1, |
271 | | - int(dims[0] // nwins0), |
272 | | - int(dims[1] // nwins1), |
273 | | - ), dimsd |
274 | | - Pop.name = name |
275 | | - return Pop |
| 310 | + if self.tapertype is not None: |
| 311 | + ywins = ywins * self.taps |
| 312 | + if self.simOp: |
| 313 | + y = self.Op.H @ ywins |
| 314 | + else: |
| 315 | + y = ncp.zeros(self.dims, dtype=self.dtype) |
| 316 | + for iwin0 in range(self.dims[0]): |
| 317 | + for iwin1 in range(self.dims[1]): |
| 318 | + y[iwin0, iwin1] = self.Op.rmatvec( |
| 319 | + ywins[iwin0, iwin1].ravel() |
| 320 | + ).reshape(self.dims[2], self.dims[3]) |
| 321 | + return y |
0 commit comments