Skip to content

Commit 4860e7a

Browse files
authored
Merge pull request #583 from solldavid/sollberger/3d-dwt
Feature: N-dimensional discrete wavelet transforms
2 parents b19a63c + 9b0c9b9 commit 4860e7a

File tree

7 files changed

+253
-3
lines changed

7 files changed

+253
-3
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,4 @@ A list of video tutorials to learn more about PyLops:
150150
* Wei Zhang, ZhangWeiGeo
151151
* Fedor Goncharov, fedor-goncharov
152152
* Alex Rakowski, alex-rakowski
153+
* David Sollberger, solldavid

docs/source/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ Signal processing
102102
Shift
103103
DWT
104104
DWT2D
105+
DWTND
105106
DCT
106107
DTCWT
107108
Seislet

docs/source/credits.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ Contributors
2222
* `Wei Zhang <https://github.com/ZhangWeiGeo>`_, ZhangWeiGeo
2323
* `Fedor Goncharov <https://github.com/fedor-goncharov>`_, fedor-goncharov
2424
* `Alex Rakowski <https://github.com/alex-rakowski>`_, alex-rakowski
25+
* `David Sollberger <https://github.com/solldavid>`_, solldavid

examples/plot_wavelet.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""
22
Wavelet transform
33
=================
4-
This example shows how to use the :py:class:`pylops.DWT` and
5-
:py:class:`pylops.DWT2D` operators to perform 1- and 2-dimensional DWT.
4+
This example shows how to use the :py:class:`pylops.DWT`,
5+
:py:class:`pylops.DWT2D`, and :py:class:`pylops.DWTND` operators
6+
to perform 1-, 2-, and N-dimensional DWT.
67
"""
78
import matplotlib.pyplot as plt
89
import numpy as np
@@ -67,3 +68,46 @@
6768
axs[1, 1].set_title("DWT2 coefficients (zeroed)")
6869
axs[1, 1].axis("tight")
6970
plt.tight_layout()
71+
72+
###############################################################################
73+
# Let us now try the same with a 3D volumetric model, where we use the
74+
# N-dimensional DWT. This time, we only retain 10 percent of the coefficients
75+
# of the DWT.
76+
77+
nx = 128
78+
ny = 256
79+
nz = 128
80+
81+
x = np.arange(nx)
82+
y = np.arange(ny)
83+
z = np.arange(nz)
84+
85+
xx, yy, zz = np.meshgrid(x, y, z, indexing="ij")
86+
# Generate a 3D model with two block anomalies
87+
m = np.ones_like(xx, dtype=float)
88+
block1 = (xx > 10) & (xx < 60) & (yy > 100) & (yy < 150) & (zz > 20) & (zz < 70)
89+
block2 = (xx > 70) & (xx < 80) & (yy > 100) & (yy < 200) & (zz > 10) & (zz < 50)
90+
m[block1] = 1.2
91+
m[block2] = 0.8
92+
Wop = pylops.signalprocessing.DWTND((nx, ny, nz), wavelet="haar", level=3)
93+
y = Wop * m
94+
95+
ratio = 0.1
96+
yf = y.copy()
97+
yf.flat[int(ratio * y.size) :] = 0
98+
iminv = Wop.H * yf
99+
100+
fig, axs = plt.subplots(2, 2, figsize=(6, 6))
101+
axs[0, 0].imshow(m[:, :, 30], cmap="gray")
102+
axs[0, 0].set_title("Model (Slice at z=30)")
103+
axs[0, 0].axis("tight")
104+
axs[0, 1].imshow(y[:, :, 90], cmap="gray_r")
105+
axs[0, 1].set_title("DWTNT coefficients")
106+
axs[0, 1].axis("tight")
107+
axs[1, 0].imshow(iminv[:, :, 30], cmap="gray")
108+
axs[1, 0].set_title("Reconstructed model (Slice at z=30)")
109+
axs[1, 0].axis("tight")
110+
axs[1, 1].imshow(yf[:, :, 90], cmap="gray_r")
111+
axs[1, 1].set_title("DWTNT coefficients (zeroed)")
112+
axs[1, 1].axis("tight")
113+
plt.tight_layout()

pylops/signalprocessing/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Shift Fractional Shift operator.
2424
DWT One dimensional Wavelet operator.
2525
DWT2D Two dimensional Wavelet operator.
26+
DWTND N-dimensional Wavelet operator.
2627
DCT Discrete Cosine Transform.
2728
DTCWT Dual-Tree Complex Wavelet Transform.
2829
Radon2D Two dimensional Radon transform.
@@ -61,6 +62,7 @@
6162
from .fredholm1 import *
6263
from .dwt import *
6364
from .dwt2d import *
65+
from .dwtnd import *
6466
from .seislet import *
6567
from .dct import *
6668
from .dtcwt import *
@@ -93,6 +95,7 @@
9395
"Fredholm1",
9496
"DWT",
9597
"DWT2D",
98+
"DWTND",
9699
"Seislet",
97100
"DCT",
98101
"DTCWT",

pylops/signalprocessing/dwtnd.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
__all__ = ["DWTND"]
2+
3+
import logging
4+
from math import ceil, log
5+
6+
import numpy as np
7+
8+
from pylops import LinearOperator
9+
from pylops.basicoperators import Pad
10+
from pylops.utils import deps
11+
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
12+
13+
from .dwt import _adjointwavelet, _checkwavelet
14+
15+
pywt_message = deps.pywt_import("the dwtnd module")
16+
17+
if pywt_message is None:
18+
import pywt
19+
20+
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.WARNING)
21+
22+
23+
class DWTND(LinearOperator):
24+
"""N-dimensional Wavelet operator.
25+
26+
Apply ND-Wavelet transform along N ``axes`` of a
27+
multi-dimensional array of size ``dims``.
28+
29+
Note that the Wavelet operator is an overload of the ``pywt``
30+
implementation of the wavelet transform. Refer to
31+
https://pywavelets.readthedocs.io for a detailed description of the
32+
input parameters.
33+
34+
Defaults to a 3D wavelet transform along the last three dimensions
35+
of the input array.
36+
37+
Parameters
38+
----------
39+
dims : :obj:`tuple`
40+
Number of samples for each dimension
41+
axes : :obj:`int`, optional
42+
Axis along which DWTND is applied
43+
wavelet : :obj:`str`, optional
44+
Name of wavelet type. Use :func:`pywt.wavelist(kind='discrete')` for
45+
a list of available wavelets.
46+
level : :obj:`int`, optional
47+
Number of scaling levels (must be >=0).
48+
dtype : :obj:`str`, optional
49+
Type of elements in input array.
50+
name : :obj:`str`, optional
51+
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
52+
53+
Attributes
54+
----------
55+
shape : :obj:`tuple`
56+
Operator shape
57+
explicit : :obj:`bool`
58+
Operator contains a matrix that can be solved explicitly
59+
(``True``) or not (``False``)
60+
61+
Raises
62+
------
63+
ModuleNotFoundError
64+
If ``pywt`` is not installed
65+
ValueError
66+
If ``wavelet`` does not belong to ``pywt.families``
67+
68+
Notes
69+
-----
70+
The Wavelet operator applies the N-dimensional multilevel Discrete
71+
Wavelet Transform (DWTN) in forward mode and the N-dimensional multilevel
72+
Inverse Discrete Wavelet Transform (IDWTN) in adjoint mode.
73+
74+
"""
75+
76+
def __init__(
77+
self,
78+
dims: InputDimsLike,
79+
axes: InputDimsLike = (-3, -2, -1),
80+
wavelet: str = "haar",
81+
level: int = 1,
82+
dtype: DTypeLike = "float64",
83+
name: str = "D",
84+
) -> None:
85+
if pywt_message is not None:
86+
raise ModuleNotFoundError(pywt_message)
87+
_checkwavelet(wavelet)
88+
89+
# define padding for length to be power of 2
90+
ndimpow2 = [max(2 ** ceil(log(dims[ax], 2)), 2**level) for ax in axes]
91+
pad = [(0, 0)] * len(dims)
92+
for i, ax in enumerate(axes):
93+
pad[ax] = (0, ndimpow2[i] - dims[ax])
94+
self.pad = Pad(dims, pad)
95+
self.axes = axes
96+
dimsd = list(dims)
97+
for i, ax in enumerate(axes):
98+
dimsd[ax] = ndimpow2[i]
99+
super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name)
100+
101+
# apply transform once again to find out slices
102+
_, self.sl = pywt.coeffs_to_array(
103+
pywt.wavedecn(
104+
np.ones(self.dimsd),
105+
wavelet=wavelet,
106+
level=level,
107+
mode="periodization",
108+
axes=self.axes,
109+
),
110+
axes=self.axes,
111+
)
112+
self.wavelet = wavelet
113+
self.waveletadj = _adjointwavelet(wavelet)
114+
self.level = level
115+
116+
def _matvec(self, x: NDArray) -> NDArray:
117+
x = self.pad.matvec(x)
118+
x = np.reshape(x, self.dimsd)
119+
y = pywt.coeffs_to_array(
120+
pywt.wavedecn(
121+
x,
122+
wavelet=self.wavelet,
123+
level=self.level,
124+
mode="periodization",
125+
axes=self.axes,
126+
),
127+
axes=(self.axes),
128+
)[0]
129+
return y.ravel()
130+
131+
def _rmatvec(self, x: NDArray) -> NDArray:
132+
x = np.reshape(x, self.dimsd)
133+
x = pywt.array_to_coeffs(x, self.sl, output_format="wavedecn")
134+
y = pywt.waverecn(
135+
x, wavelet=self.waveletadj, mode="periodization", axes=self.axes
136+
)
137+
y = self.pad.rmatvec(y.ravel())
138+
return y

pytests/test_dwts.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,20 @@
33
from numpy.testing import assert_array_almost_equal
44
from scipy.sparse.linalg import lsqr
55

6-
from pylops.signalprocessing import DWT, DWT2D
6+
from pylops.signalprocessing import DWT, DWT2D, DWTND
77
from pylops.utils import dottest
88

99
par1 = {"ny": 7, "nx": 9, "nt": 10, "imag": 0, "dtype": "float32"} # real
1010
par2 = {"ny": 7, "nx": 9, "nt": 10, "imag": 1j, "dtype": "complex64"} # complex
11+
par3 = {"ny": 7, "nx": 9, "nz": 9, "nt": 10, "imag": 0, "dtype": "float32"} # real 4D
12+
par4 = {
13+
"ny": 7,
14+
"nx": 9,
15+
"nz": 9,
16+
"nt": 10,
17+
"imag": 1j,
18+
"dtype": "complex64",
19+
} # complex 4D
1120

1221
np.random.seed(10)
1322

@@ -133,3 +142,56 @@ def test_DWT2D_3dsignal(par):
133142

134143
assert_array_almost_equal(x.ravel(), xadj, decimal=8)
135144
assert_array_almost_equal(x.ravel(), xinv, decimal=8)
145+
146+
147+
@pytest.mark.parametrize("par", [(par3), (par4)])
148+
def test_DWTND_3dsignal(par):
149+
"""Dot-test and inversion for DWTND operator for 3d signal"""
150+
DWTop = DWTND(
151+
dims=(par["nt"], par["nx"], par["ny"]), axes=(0, 1, 2), wavelet="haar", level=3
152+
)
153+
x = np.random.normal(0.0, 1.0, (par["nt"], par["nx"], par["ny"])) + par[
154+
"imag"
155+
] * np.random.normal(0.0, 1.0, (par["nt"], par["nx"], par["ny"]))
156+
157+
assert dottest(
158+
DWTop, DWTop.shape[0], DWTop.shape[1], complexflag=0 if par["imag"] == 0 else 3
159+
)
160+
161+
y = DWTop * x.ravel()
162+
xadj = DWTop.H * y # adjoint is same as inverse for dwt
163+
xinv = lsqr(DWTop, y, damp=1e-10, iter_lim=10, atol=1e-8, btol=1e-8, show=0)[0]
164+
165+
assert_array_almost_equal(x.ravel(), xadj, decimal=8)
166+
assert_array_almost_equal(x.ravel(), xinv, decimal=8)
167+
168+
169+
@pytest.mark.parametrize("par", [(par3), (par4)])
170+
def test_DWTND_4dsignal(par):
171+
"""Dot-test and inversion for DWTND operator for 4d signal"""
172+
for axes in [(0, 1, 2), (0, 2, 3), (1, 2, 3), (0, 1, 3), (0, 1, 2, 3)]:
173+
DWTop = DWTND(
174+
dims=(par["nt"], par["nx"], par["ny"], par["nz"]),
175+
axes=axes,
176+
wavelet="haar",
177+
level=3,
178+
)
179+
x = np.random.normal(
180+
0.0, 1.0, (par["nt"], par["nx"], par["ny"], par["nz"])
181+
) + par["imag"] * np.random.normal(
182+
0.0, 1.0, (par["nt"], par["nx"], par["ny"], par["nz"])
183+
)
184+
185+
assert dottest(
186+
DWTop,
187+
DWTop.shape[0],
188+
DWTop.shape[1],
189+
complexflag=0 if par["imag"] == 0 else 3,
190+
)
191+
192+
y = DWTop * x.ravel()
193+
xadj = DWTop.H * y # adjoint is same as inverse for dwt
194+
xinv = lsqr(DWTop, y, damp=1e-10, iter_lim=10, atol=1e-8, btol=1e-8, show=0)[0]
195+
196+
assert_array_almost_equal(x.ravel(), xadj, decimal=8)
197+
assert_array_almost_equal(x.ravel(), xinv, decimal=8)

0 commit comments

Comments
 (0)