Skip to content

Commit 65571ee

Browse files
authored
Merge pull request #569 from mrava87/feature-dtcwtclean
feat: added DTCWT 1d operator
2 parents e588fe3 + b56ab83 commit 65571ee

File tree

11 files changed

+387
-3
lines changed

11 files changed

+387
-3
lines changed

docs/source/installation.rst

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,9 +319,20 @@ of GPUs should install it prior to installing PyLops as described in :ref:`Optio
319319
In alphabetic order:
320320

321321

322+
dtcwt
323+
-----
324+
`dtcwt <https://dtcwt.readthedocs.io/en/0.12.0/>`_ is a library used to implement the DT-CWT operators.
325+
326+
Install it via ``pip`` with:
327+
328+
.. code-block:: bash
329+
330+
>> pip install dtcwt
331+
332+
322333
Devito
323334
------
324-
`Devito <https://github.com/devitocodes/devito>`_ is library used to solve PDEs via
335+
`Devito <https://github.com/devitocodes/devito>`_ is a library used to solve PDEs via
325336
the finite-difference method. It is used in PyLops to compute wavefields
326337
:py:class:`pylops.waveeqprocessing.AcousticWave2D`
327338

environment-dev-arm.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dependencies:
2525
- black
2626
- pip:
2727
- devito
28+
- dtcwt
2829
- scikit-fmm
2930
- spgl1
3031
- pytest-runner

environment-dev.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ dependencies:
2626
- black
2727
- pip:
2828
- devito
29+
- dtcwt
2930
- scikit-fmm
3031
- spgl1
3132
- pytest-runner

examples/plot_dtcwt.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""
2+
Dual-Tree Complex Wavelet Transform
3+
===================================
4+
This example shows how to use the :py:class:`pylops.signalprocessing.DTCWT` operator to perform the
5+
1D Dual-Tree Complex Wavelet Transform on a (single or multi-dimensional) input array. Such a transform
6+
provides advantages over the DWT which lacks shift invariance in 1-D and directional sensitivity in N-D.
7+
"""
8+
9+
import matplotlib.pyplot as plt
10+
import numpy as np
11+
import pywt
12+
13+
import pylops
14+
15+
plt.close("all")
16+
17+
###############################################################################
18+
# To begin with, let's define two 1D arrays with a spike at slightly different location
19+
20+
n = 128
21+
x = np.zeros(n)
22+
x1 = np.zeros(n)
23+
24+
x[59] = 1
25+
x1[63] = 1
26+
27+
###############################################################################
28+
# We now create the DTCWT operator with the shape of our input array. The DTCWT transform
29+
# provides a Pyramid object that is internally flattened out into a vector. Here we re-obtain
30+
# the Pyramid object such that we can visualize the different scales indipendently.
31+
32+
level = 3
33+
DCOp = pylops.signalprocessing.DTCWT(dims=n, level=level)
34+
Xc = DCOp.get_pyramid(DCOp @ x)
35+
Xc1 = DCOp.get_pyramid(DCOp @ x1)
36+
37+
###############################################################################
38+
# To prove the superiority of the DTCWT transform over the DWT in shift-invariance,
39+
# let's also compute the DWT transform of these two signals and compare the coefficents
40+
# of both transform at level 3. As you will see, the coefficients change completely for
41+
# the DWT despite the two input signals are very similar; this is not the case for the
42+
# DCWT transform.
43+
44+
DOp = pylops.signalprocessing.DWT(dims=n, level=level, wavelet="sym7")
45+
X = pywt.array_to_coeffs(DOp @ x, DOp.sl, output_format="wavedecn")
46+
X1 = pywt.array_to_coeffs(DOp @ x1, DOp.sl, output_format="wavedecn")
47+
48+
fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, figsize=(10, 5))
49+
axs[0, 0].stem(np.abs(X[1]["d"]), linefmt="k", markerfmt=".k", basefmt="k")
50+
axs[0, 0].set_title(f"DWT (Norm={np.linalg.norm(np.abs(X[1]['d']))**2:.3f})")
51+
axs[0, 1].stem(np.abs(X1[1]["d"]), linefmt="k", markerfmt=".k", basefmt="k")
52+
axs[0, 1].set_title(f"DWT (Norm={np.linalg.norm(np.abs(X1[1]['d']))**2:.3f})")
53+
axs[1, 0].stem(np.abs(Xc.highpasses[2]), linefmt="k", markerfmt=".k", basefmt="k")
54+
axs[1, 0].set_title(f"DCWT (Norm={np.linalg.norm(np.abs(Xc.highpasses[2]))**2:.3f})")
55+
axs[1, 1].stem(np.abs(Xc1.highpasses[2]), linefmt="k", markerfmt=".k", basefmt="k")
56+
axs[1, 1].set_title(f"DCWT (Norm={np.linalg.norm(np.abs(Xc1.highpasses[2]))**2:.3f})")
57+
plt.tight_layout()
58+
59+
###################################################################################
60+
# The DTCWT can also be performed on multi-dimension arrays, where the parameter
61+
# ``axis`` is used to define the axis over which the transform is performed. Let's
62+
# just replicate our input signal over the second axis and see how the transform
63+
# will produce the same series of coefficients for all replicas.
64+
65+
nrepeat = 10
66+
x = np.repeat(np.random.rand(n, 1), 10, axis=1).T
67+
68+
level = 3
69+
DCOp = pylops.signalprocessing.DTCWT(dims=(nrepeat, n), level=level, axis=1)
70+
X = DCOp @ x
71+
72+
fig, axs = plt.subplots(1, 2, sharey=True, figsize=(10, 3))
73+
axs[0].imshow(X[0])
74+
axs[0].axis("tight")
75+
axs[0].set_xlabel("Coeffs")
76+
axs[0].set_ylabel("Replicas")
77+
axs[0].set_title("DTCWT Real")
78+
axs[1].imshow(X[1])
79+
axs[1].axis("tight")
80+
axs[1].set_xlabel("Coeffs")
81+
axs[1].set_title("DTCWT Imag")
82+
plt.tight_layout()

pylops/signalprocessing/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
DWT One dimensional Wavelet operator.
2525
DWT2D Two dimensional Wavelet operator.
2626
DCT Discrete Cosine Transform.
27-
Seislet Two dimensional Seislet operator.
27+
DTCWT Dual-Tree Complex Wavelet Transform.
2828
Radon2D Two dimensional Radon transform.
2929
Radon3D Three dimensional Radon transform.
30+
Seislet Two dimensional Seislet operator.
3031
Sliding1D 1D Sliding transform operator.
3132
Sliding2D 2D Sliding transform operator.
3233
Sliding3D 3D Sliding transform operator.
@@ -62,6 +63,8 @@
6263
from .dwt2d import *
6364
from .seislet import *
6465
from .dct import *
66+
from .dtcwt import *
67+
6568

6669
__all__ = [
6770
"FFT",
@@ -92,4 +95,5 @@
9295
"DWT2D",
9396
"Seislet",
9497
"DCT",
98+
"DTCWT",
9599
]

pylops/signalprocessing/dct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class DCT(LinearOperator):
2929
axes : :obj:`int` or :obj:`list`, optional
3030
Axes over which the DCT is computed. If ``None``, the transform is applied
3131
over all axes.
32-
workers :obj:`int`, optional
32+
workers : :obj:`int`, optional
3333
Maximum number of workers to use for parallel computation. If negative,
3434
the value wraps around from os.cpu_count().
3535
dtype : :obj:`DTypeLike`, optional

pylops/signalprocessing/dtcwt.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
__all__ = ["DTCWT"]
2+
3+
from typing import Union
4+
5+
import numpy as np
6+
7+
from pylops import LinearOperator
8+
from pylops.utils import deps
9+
from pylops.utils._internal import _value_or_sized_to_tuple
10+
from pylops.utils.decorators import reshaped
11+
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray
12+
13+
dtcwt_message = deps.dtcwt_import("the dtcwt module")
14+
15+
if dtcwt_message is None:
16+
import dtcwt
17+
18+
19+
class DTCWT(LinearOperator):
20+
r"""Dual-Tree Complex Wavelet Transform
21+
22+
Perform 1D Dual-Tree Complex Wavelet Transform along an ``axis`` of a
23+
multi-dimensional array of size ``dims``.
24+
25+
Note that the DTCWT operator is an overload of the ``dtcwt``
26+
implementation of the DT-CWT transform. Refer to
27+
https://dtcwt.readthedocs.io for a detailed description of the
28+
input parameters.
29+
30+
Parameters
31+
----------
32+
dims : :obj:`int` or :obj:`tuple`
33+
Number of samples for each dimension.
34+
birot : :obj:`str`, optional
35+
Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot`. Default is `"near_sym_a"`.
36+
qshift : :obj:`str`, optional
37+
Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift`. Default is `"qshift_a"`
38+
level : :obj:`int`, optional
39+
Number of levels of wavelet decomposition. Default is 3.
40+
include_scale : :obj:`bool`, optional
41+
Include scales in pyramid. See :py:class:`dtcwt.Pyramid`. Default is False.
42+
axis : :obj:`int`, optional
43+
Axis on which the transform is performed.
44+
dtype : :obj:`DTypeLike`, optional
45+
Type of elements in input array.
46+
name : :obj:`str`, optional
47+
Name of operator (to be used by :func:`pylops.utils.describe.describe`)
48+
49+
Notes
50+
-----
51+
The DTCWT operator applies the dual-tree complex wavelet transform
52+
in forward mode and the dual-tree complex inverse wavelet transform in adjoint mode
53+
from the ``dtcwt`` library.
54+
55+
The ``dtcwt`` library uses a Pyramid object to represent the signal in the transformed domain,
56+
which is composed of:
57+
- `lowpass` (coarsest scale lowpass signal);
58+
- `highpasses` (complex subband coefficients for corresponding scales);
59+
- `scales` (lowpass signal for corresponding scales finest to coarsest).
60+
61+
To make the dtcwt forward() and inverse() functions compatible with PyLops, in forward model
62+
the Pyramid object is flattened out and all coefficients (high-pass and low pass coefficients)
63+
are appended into one array using the `_coeff_to_array` method.
64+
65+
In adjoint mode, the input array is transformed back into a Pyramid object using the `_array_to_coeff`
66+
method and then the inverse transform is performed.
67+
68+
"""
69+
70+
def __init__(
71+
self,
72+
dims: Union[int, InputDimsLike],
73+
biort: str = "near_sym_a",
74+
qshift: str = "qshift_a",
75+
level: int = 3,
76+
include_scale: bool = False,
77+
axis: int = -1,
78+
dtype: DTypeLike = "float64",
79+
name: str = "C",
80+
) -> None:
81+
if dtcwt_message is not None:
82+
raise NotImplementedError(dtcwt_message)
83+
84+
dims = _value_or_sized_to_tuple(dims)
85+
self.ndim = len(dims)
86+
self.axis = axis
87+
88+
self.otherdims = int(np.prod(dims) / dims[self.axis])
89+
self.dims_swapped = list(dims)
90+
self.dims_swapped[0], self.dims_swapped[self.axis] = (
91+
self.dims_swapped[self.axis],
92+
self.dims_swapped[0],
93+
)
94+
self.dims_swapped = tuple(self.dims_swapped)
95+
self.level = level
96+
self.include_scale = include_scale
97+
98+
# dry-run of transform to find dimensions of coefficients at different levels
99+
self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift)
100+
self._interpret_coeffs(dims, self.axis)
101+
102+
dimsd = list(dims)
103+
dimsd[self.axis] = self.coeff_array_size
104+
self.dimsd_swapped = list(dimsd)
105+
self.dimsd_swapped[0], self.dimsd_swapped[self.axis] = (
106+
self.dimsd_swapped[self.axis],
107+
self.dimsd_swapped[0],
108+
)
109+
self.dimsd_swapped = tuple(self.dimsd_swapped)
110+
dimsd = tuple(
111+
[
112+
2,
113+
]
114+
+ dimsd
115+
)
116+
117+
super().__init__(
118+
dtype=np.dtype(dtype),
119+
clinear=False,
120+
dims=dims,
121+
dimsd=dimsd,
122+
name=name,
123+
)
124+
125+
def _interpret_coeffs(self, dims, axis):
126+
x = np.ones(dims[axis])
127+
pyr = self._transform.forward(
128+
x, nlevels=self.level, include_scale=self.include_scale
129+
)
130+
self.lowpass_size = pyr.lowpass.size
131+
self.coeff_array_size = self.lowpass_size
132+
self.highpass_sizes = []
133+
for _h in pyr.highpasses:
134+
self.highpass_sizes.append(_h.size)
135+
self.coeff_array_size += _h.size
136+
137+
def _nd_to_2d(self, arr_nd):
138+
arr_2d = arr_nd.reshape(self.dims[self.axis], -1).squeeze()
139+
return arr_2d
140+
141+
def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray:
142+
highpass_coeffs = np.vstack([h for h in pyr.highpasses])
143+
coeffs = np.concatenate((highpass_coeffs, pyr.lowpass), axis=0)
144+
return coeffs
145+
146+
def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid:
147+
lowpass = (X[-self.lowpass_size :].real).reshape((-1, self.otherdims))
148+
_ptr = 0
149+
highpasses = ()
150+
for _sl in self.highpass_sizes:
151+
_h = X[_ptr : _ptr + _sl]
152+
_ptr += _sl
153+
_h = _h.reshape(-1, self.otherdims)
154+
highpasses += (_h,)
155+
return dtcwt.Pyramid(lowpass, highpasses)
156+
157+
def get_pyramid(self, x: NDArray) -> dtcwt.Pyramid:
158+
"""Return Pyramid object from flat real-valued array"""
159+
return self._array_to_coeff(x[0] + 1j * x[1])
160+
161+
@reshaped
162+
def _matvec(self, x: NDArray) -> NDArray:
163+
x = x.swapaxes(self.axis, 0)
164+
y = self._nd_to_2d(x)
165+
y = self._coeff_to_array(
166+
self._transform.forward(
167+
y, nlevels=self.level, include_scale=self.include_scale
168+
)
169+
)
170+
y = y.reshape(self.dimsd_swapped)
171+
y = y.swapaxes(self.axis, 0)
172+
y = np.concatenate([y.real[np.newaxis], y.imag[np.newaxis]])
173+
return y
174+
175+
@reshaped
176+
def _rmatvec(self, x: NDArray) -> NDArray:
177+
x = x[0] + 1j * x[1]
178+
x = x.swapaxes(self.axis, 0)
179+
y = self._transform.inverse(self._array_to_coeff(x))
180+
y = y.reshape(self.dims_swapped)
181+
y = y.swapaxes(self.axis, 0)
182+
return y

pylops/utils/deps.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__all__ = [
22
"cupy_enabled",
33
"devito_enabled",
4+
"dtcwt_enabled",
45
"numba_enabled",
56
"pyfftw_enabled",
67
"pywt_enabled",
@@ -67,6 +68,23 @@ def devito_import(message: Optional[str] = None) -> str:
6768
return devito_message
6869

6970

71+
def dtcwt_import(message: Optional[str] = None) -> str:
72+
if dtcwt_enabled:
73+
try:
74+
import dtcwt # noqa: F401
75+
76+
dtcwt_message = None
77+
except Exception as e:
78+
dtcwt_message = f"Failed to import dtcwt (error:{e})."
79+
else:
80+
dtcwt_message = (
81+
f"Dtcwt not available. "
82+
f"In order to be able to use "
83+
f'{message} run "pip install dtcwt".'
84+
)
85+
return dtcwt_message
86+
87+
7088
def numba_import(message: Optional[str] = None) -> str:
7189
if numba_enabled:
7290
try:
@@ -187,6 +205,7 @@ def sympy_import(message: Optional[str] = None) -> str:
187205
True if (cupy_import() is None and int(os.getenv("CUPY_PYLOPS", 1)) == 1) else False
188206
)
189207
devito_enabled = util.find_spec("devito") is not None
208+
dtcwt_enabled = util.find_spec("dtcwt") is not None
190209
numba_enabled = util.find_spec("numba") is not None
191210
pyfftw_enabled = util.find_spec("pyfftw") is not None
192211
pywt_enabled = util.find_spec("pywt") is not None

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ advanced = [
4343
"PyWavelets",
4444
"scikit-fmm",
4545
"spgl1",
46+
"dtcwt",
4647
]
4748

4849
[tool.setuptools.packages.find]

0 commit comments

Comments
 (0)