|
3 | 3 | from numpy.testing import assert_array_almost_equal |
4 | 4 | from scipy.sparse.linalg import lsqr |
5 | 5 |
|
6 | | -from pylops.signalprocessing import DWT, DWT2D |
| 6 | +from pylops.signalprocessing import DWT, DWT2D, DWTND |
7 | 7 | from pylops.utils import dottest |
8 | 8 |
|
9 | 9 | par1 = {"ny": 7, "nx": 9, "nt": 10, "imag": 0, "dtype": "float32"} # real |
10 | 10 | par2 = {"ny": 7, "nx": 9, "nt": 10, "imag": 1j, "dtype": "complex64"} # complex |
| 11 | +par3 = {"ny": 7, "nx": 9, "nz": 9, "nt": 10, "imag": 0, "dtype": "float32"} # real 4D |
| 12 | +par4 = { |
| 13 | + "ny": 7, |
| 14 | + "nx": 9, |
| 15 | + "nz": 9, |
| 16 | + "nt": 10, |
| 17 | + "imag": 1j, |
| 18 | + "dtype": "complex64", |
| 19 | +} # complex 4D |
11 | 20 |
|
12 | 21 | np.random.seed(10) |
13 | 22 |
|
@@ -133,3 +142,56 @@ def test_DWT2D_3dsignal(par): |
133 | 142 |
|
134 | 143 | assert_array_almost_equal(x.ravel(), xadj, decimal=8) |
135 | 144 | assert_array_almost_equal(x.ravel(), xinv, decimal=8) |
| 145 | + |
| 146 | + |
| 147 | +@pytest.mark.parametrize("par", [(par3), (par4)]) |
| 148 | +def test_DWTND_3dsignal(par): |
| 149 | + """Dot-test and inversion for DWTND operator for 3d signal""" |
| 150 | + DWTop = DWTND( |
| 151 | + dims=(par["nt"], par["nx"], par["ny"]), axes=(0, 1, 2), wavelet="haar", level=3 |
| 152 | + ) |
| 153 | + x = np.random.normal(0.0, 1.0, (par["nt"], par["nx"], par["ny"])) + par[ |
| 154 | + "imag" |
| 155 | + ] * np.random.normal(0.0, 1.0, (par["nt"], par["nx"], par["ny"])) |
| 156 | + |
| 157 | + assert dottest( |
| 158 | + DWTop, DWTop.shape[0], DWTop.shape[1], complexflag=0 if par["imag"] == 0 else 3 |
| 159 | + ) |
| 160 | + |
| 161 | + y = DWTop * x.ravel() |
| 162 | + xadj = DWTop.H * y # adjoint is same as inverse for dwt |
| 163 | + xinv = lsqr(DWTop, y, damp=1e-10, iter_lim=10, atol=1e-8, btol=1e-8, show=0)[0] |
| 164 | + |
| 165 | + assert_array_almost_equal(x.ravel(), xadj, decimal=8) |
| 166 | + assert_array_almost_equal(x.ravel(), xinv, decimal=8) |
| 167 | + |
| 168 | + |
| 169 | +@pytest.mark.parametrize("par", [(par3), (par4)]) |
| 170 | +def test_DWTND_4dsignal(par): |
| 171 | + """Dot-test and inversion for DWTND operator for 4d signal""" |
| 172 | + for axes in [(0, 1, 2), (0, 2, 3), (1, 2, 3), (0, 1, 3), (0, 1, 2, 3)]: |
| 173 | + DWTop = DWTND( |
| 174 | + dims=(par["nt"], par["nx"], par["ny"], par["nz"]), |
| 175 | + axes=axes, |
| 176 | + wavelet="haar", |
| 177 | + level=3, |
| 178 | + ) |
| 179 | + x = np.random.normal( |
| 180 | + 0.0, 1.0, (par["nt"], par["nx"], par["ny"], par["nz"]) |
| 181 | + ) + par["imag"] * np.random.normal( |
| 182 | + 0.0, 1.0, (par["nt"], par["nx"], par["ny"], par["nz"]) |
| 183 | + ) |
| 184 | + |
| 185 | + assert dottest( |
| 186 | + DWTop, |
| 187 | + DWTop.shape[0], |
| 188 | + DWTop.shape[1], |
| 189 | + complexflag=0 if par["imag"] == 0 else 3, |
| 190 | + ) |
| 191 | + |
| 192 | + y = DWTop * x.ravel() |
| 193 | + xadj = DWTop.H * y # adjoint is same as inverse for dwt |
| 194 | + xinv = lsqr(DWTop, y, damp=1e-10, iter_lim=10, atol=1e-8, btol=1e-8, show=0)[0] |
| 195 | + |
| 196 | + assert_array_almost_equal(x.ravel(), xadj, decimal=8) |
| 197 | + assert_array_almost_equal(x.ravel(), xinv, decimal=8) |
0 commit comments