|
| 1 | +__all__ = ["DTCWT"] |
| 2 | + |
| 3 | +from typing import Union |
| 4 | + |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +from pylops import LinearOperator |
| 8 | +from pylops.utils import deps |
| 9 | +from pylops.utils._internal import _value_or_sized_to_tuple |
| 10 | +from pylops.utils.decorators import reshaped |
| 11 | +from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray |
| 12 | + |
| 13 | +dtcwt_message = deps.dtcwt_import("the dtcwt module") |
| 14 | + |
| 15 | +if dtcwt_message is None: |
| 16 | + import dtcwt |
| 17 | + |
| 18 | + |
| 19 | +class DTCWT(LinearOperator): |
| 20 | + r"""Dual-Tree Complex Wavelet Transform |
| 21 | +
|
| 22 | + Perform 1D Dual-Tree Complex Wavelet Transform along an ``axis`` of a |
| 23 | + multi-dimensional array of size ``dims``. |
| 24 | +
|
| 25 | + Note that the DTCWT operator is an overload of the ``dtcwt`` |
| 26 | + implementation of the DT-CWT transform. Refer to |
| 27 | + https://dtcwt.readthedocs.io for a detailed description of the |
| 28 | + input parameters. |
| 29 | +
|
| 30 | + Parameters |
| 31 | + ---------- |
| 32 | + dims : :obj:`int` or :obj:`tuple` |
| 33 | + Number of samples for each dimension. |
| 34 | + birot : :obj:`str`, optional |
| 35 | + Level 1 wavelets to use. See :py:func:`dtcwt.coeffs.birot`. Default is `"near_sym_a"`. |
| 36 | + qshift : :obj:`str`, optional |
| 37 | + Level >= 2 wavelets to use. See :py:func:`dtcwt.coeffs.qshift`. Default is `"qshift_a"` |
| 38 | + level : :obj:`int`, optional |
| 39 | + Number of levels of wavelet decomposition. Default is 3. |
| 40 | + include_scale : :obj:`bool`, optional |
| 41 | + Include scales in pyramid. See :py:class:`dtcwt.Pyramid`. Default is False. |
| 42 | + axis : :obj:`int`, optional |
| 43 | + Axis on which the transform is performed. |
| 44 | + dtype : :obj:`DTypeLike`, optional |
| 45 | + Type of elements in input array. |
| 46 | + name : :obj:`str`, optional |
| 47 | + Name of operator (to be used by :func:`pylops.utils.describe.describe`) |
| 48 | +
|
| 49 | + Notes |
| 50 | + ----- |
| 51 | + The DTCWT operator applies the dual-tree complex wavelet transform |
| 52 | + in forward mode and the dual-tree complex inverse wavelet transform in adjoint mode |
| 53 | + from the ``dtcwt`` library. |
| 54 | +
|
| 55 | + The ``dtcwt`` library uses a Pyramid object to represent the signal in the transformed domain, |
| 56 | + which is composed of: |
| 57 | + - `lowpass` (coarsest scale lowpass signal); |
| 58 | + - `highpasses` (complex subband coefficients for corresponding scales); |
| 59 | + - `scales` (lowpass signal for corresponding scales finest to coarsest). |
| 60 | +
|
| 61 | + To make the dtcwt forward() and inverse() functions compatible with PyLops, in forward model |
| 62 | + the Pyramid object is flattened out and all coefficients (high-pass and low pass coefficients) |
| 63 | + are appended into one array using the `_coeff_to_array` method. |
| 64 | +
|
| 65 | + In adjoint mode, the input array is transformed back into a Pyramid object using the `_array_to_coeff` |
| 66 | + method and then the inverse transform is performed. |
| 67 | +
|
| 68 | + """ |
| 69 | + |
| 70 | + def __init__( |
| 71 | + self, |
| 72 | + dims: Union[int, InputDimsLike], |
| 73 | + biort: str = "near_sym_a", |
| 74 | + qshift: str = "qshift_a", |
| 75 | + level: int = 3, |
| 76 | + include_scale: bool = False, |
| 77 | + axis: int = -1, |
| 78 | + dtype: DTypeLike = "float64", |
| 79 | + name: str = "C", |
| 80 | + ) -> None: |
| 81 | + if dtcwt_message is not None: |
| 82 | + raise NotImplementedError(dtcwt_message) |
| 83 | + |
| 84 | + dims = _value_or_sized_to_tuple(dims) |
| 85 | + self.ndim = len(dims) |
| 86 | + self.axis = axis |
| 87 | + |
| 88 | + self.otherdims = int(np.prod(dims) / dims[self.axis]) |
| 89 | + self.dims_swapped = list(dims) |
| 90 | + self.dims_swapped[0], self.dims_swapped[self.axis] = ( |
| 91 | + self.dims_swapped[self.axis], |
| 92 | + self.dims_swapped[0], |
| 93 | + ) |
| 94 | + self.dims_swapped = tuple(self.dims_swapped) |
| 95 | + self.level = level |
| 96 | + self.include_scale = include_scale |
| 97 | + |
| 98 | + # dry-run of transform to find dimensions of coefficients at different levels |
| 99 | + self._transform = dtcwt.Transform1d(biort=biort, qshift=qshift) |
| 100 | + self._interpret_coeffs(dims, self.axis) |
| 101 | + |
| 102 | + dimsd = list(dims) |
| 103 | + dimsd[self.axis] = self.coeff_array_size |
| 104 | + self.dimsd_swapped = list(dimsd) |
| 105 | + self.dimsd_swapped[0], self.dimsd_swapped[self.axis] = ( |
| 106 | + self.dimsd_swapped[self.axis], |
| 107 | + self.dimsd_swapped[0], |
| 108 | + ) |
| 109 | + self.dimsd_swapped = tuple(self.dimsd_swapped) |
| 110 | + dimsd = tuple( |
| 111 | + [ |
| 112 | + 2, |
| 113 | + ] |
| 114 | + + dimsd |
| 115 | + ) |
| 116 | + |
| 117 | + super().__init__( |
| 118 | + dtype=np.dtype(dtype), |
| 119 | + clinear=False, |
| 120 | + dims=dims, |
| 121 | + dimsd=dimsd, |
| 122 | + name=name, |
| 123 | + ) |
| 124 | + |
| 125 | + def _interpret_coeffs(self, dims, axis): |
| 126 | + x = np.ones(dims[axis]) |
| 127 | + pyr = self._transform.forward( |
| 128 | + x, nlevels=self.level, include_scale=self.include_scale |
| 129 | + ) |
| 130 | + self.lowpass_size = pyr.lowpass.size |
| 131 | + self.coeff_array_size = self.lowpass_size |
| 132 | + self.highpass_sizes = [] |
| 133 | + for _h in pyr.highpasses: |
| 134 | + self.highpass_sizes.append(_h.size) |
| 135 | + self.coeff_array_size += _h.size |
| 136 | + |
| 137 | + def _nd_to_2d(self, arr_nd): |
| 138 | + arr_2d = arr_nd.reshape(self.dims[self.axis], -1).squeeze() |
| 139 | + return arr_2d |
| 140 | + |
| 141 | + def _coeff_to_array(self, pyr: dtcwt.Pyramid) -> NDArray: |
| 142 | + highpass_coeffs = np.vstack([h for h in pyr.highpasses]) |
| 143 | + coeffs = np.concatenate((highpass_coeffs, pyr.lowpass), axis=0) |
| 144 | + return coeffs |
| 145 | + |
| 146 | + def _array_to_coeff(self, X: NDArray) -> dtcwt.Pyramid: |
| 147 | + lowpass = (X[-self.lowpass_size :].real).reshape((-1, self.otherdims)) |
| 148 | + _ptr = 0 |
| 149 | + highpasses = () |
| 150 | + for _sl in self.highpass_sizes: |
| 151 | + _h = X[_ptr : _ptr + _sl] |
| 152 | + _ptr += _sl |
| 153 | + _h = _h.reshape(-1, self.otherdims) |
| 154 | + highpasses += (_h,) |
| 155 | + return dtcwt.Pyramid(lowpass, highpasses) |
| 156 | + |
| 157 | + def get_pyramid(self, x: NDArray) -> dtcwt.Pyramid: |
| 158 | + """Return Pyramid object from flat real-valued array""" |
| 159 | + return self._array_to_coeff(x[0] + 1j * x[1]) |
| 160 | + |
| 161 | + @reshaped |
| 162 | + def _matvec(self, x: NDArray) -> NDArray: |
| 163 | + x = x.swapaxes(self.axis, 0) |
| 164 | + y = self._nd_to_2d(x) |
| 165 | + y = self._coeff_to_array( |
| 166 | + self._transform.forward( |
| 167 | + y, nlevels=self.level, include_scale=self.include_scale |
| 168 | + ) |
| 169 | + ) |
| 170 | + y = y.reshape(self.dimsd_swapped) |
| 171 | + y = y.swapaxes(self.axis, 0) |
| 172 | + y = np.concatenate([y.real[np.newaxis], y.imag[np.newaxis]]) |
| 173 | + return y |
| 174 | + |
| 175 | + @reshaped |
| 176 | + def _rmatvec(self, x: NDArray) -> NDArray: |
| 177 | + x = x[0] + 1j * x[1] |
| 178 | + x = x.swapaxes(self.axis, 0) |
| 179 | + y = self._transform.inverse(self._array_to_coeff(x)) |
| 180 | + y = y.reshape(self.dims_swapped) |
| 181 | + y = y.swapaxes(self.axis, 0) |
| 182 | + return y |
0 commit comments