diff --git a/environment-dev-arm.yml b/environment-dev-arm.yml index 0081a0ad..c711fe76 100755 --- a/environment-dev-arm.yml +++ b/environment-dev-arm.yml @@ -28,6 +28,7 @@ dependencies: - pip: - devito - dtcwt + - ucurv - scikit-fmm - spgl1 - pytest-runner @@ -39,3 +40,4 @@ dependencies: - image - flake8 - mypy + \ No newline at end of file diff --git a/environment-dev.yml b/environment-dev.yml index eb51c4dc..135319f7 100755 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -29,6 +29,7 @@ dependencies: - pip: - devito - dtcwt + - ucurv - scikit-fmm - spgl1 - pytest-runner diff --git a/examples/plot_udct.py b/examples/plot_udct.py new file mode 100644 index 00000000..cf92035a --- /dev/null +++ b/examples/plot_udct.py @@ -0,0 +1,48 @@ +""" +Uniform Discrete Curvelet Transform +=================================== +This example shows how to use the :py:class:`pylops.signalprocessing.UDCT` operator to perform the +Uniform Discrete Curvelet Transform on a (multi-dimensional) input array. +""" + +import numpy as np +from ucurv import udct, zoneplate, ucurvfwd, ucurvinv, ucurv2d_show +import matplotlib.pyplot as plt +import pylops +plt.close("all") + +sz = [512, 512] +cfg = [[3, 3], [6, 6]] +res = len(cfg) +rsq = zoneplate(sz) +img = rsq - np.mean(rsq) + +transform = udct(sz, cfg, complex=False, high="curvelet") + +imband = ucurvfwd(img, transform) +plt.figure(figsize=(20, 60)) +print(imband.keys()) +plt.imshow(np.abs(ucurv2d_show(imband, transform))) +# plt.show() + +recon = ucurvinv(imband, transform) + +err = img - recon +print(np.max(np.abs(err))) +plt.figure(figsize=(20, 60)) +plt.imshow(np.real(np.concatenate((img, recon, err), axis=1))) + +plt.figure() +plt.imshow(np.abs(np.fft.fftshift(np.fft.fftn(err)))) +# plt.show() + +################################################################################ + + +sz = [256, 256] +cfg = [[3, 3], [6, 6]] +x = np.random.rand(256 * 256) +y = np.random.rand(262144) +F = pylops.signalprocessing.UDCT(sz, cfg) +print(np.dot(y, F * x)) +print(np.dot(x, F.T * y)) diff --git a/pylops/signalprocessing/__init__.py b/pylops/signalprocessing/__init__.py index a8e5ed65..b4f8fd7b 100755 --- a/pylops/signalprocessing/__init__.py +++ b/pylops/signalprocessing/__init__.py @@ -35,6 +35,7 @@ Patch2D 2D Patching transform operator. Patch3D 3D Patching transform operator. Fredholm1 Fredholm integral of first kind. + UDCT Uniform Discrete Curvelet Transform """ @@ -66,6 +67,7 @@ from .seislet import * from .dct import * from .dtcwt import * +from .udct import * __all__ = [ diff --git a/pylops/signalprocessing/udct.py b/pylops/signalprocessing/udct.py new file mode 100644 index 00000000..ccafb5db --- /dev/null +++ b/pylops/signalprocessing/udct.py @@ -0,0 +1,63 @@ +__all__ = ["UDCT"] + +import numpy as np + +from pylops import LinearOperator +from pylops.utils import deps +from pylops.utils.decorators import reshaped +from pylops.utils.typing import NDArray +from ucurv import udct, ucurvfwd, ucurvinv, bands2vec, vec2bands + +ucurv_message = deps.ucurv_import("the ucurv module") + + +class UDCT(LinearOperator): + r"""Uniform Discrete Curvelet Transform + + Perform the multidimensional discrete curvelet transforms + + The UDCT operator is a wraparound of the ucurvfwd and ucurvinv + calls in the UCURV package. Refer to + https://ucurv.readthedocs.io for a detailed description of the + input parameters. + + Parameters + ---------- + udct : :obj:`DTypeLike`, optional + Type of elements in input array. + dtype : :obj:`DTypeLike`, optional + Type of elements in input array. + name : :obj:`str`, optional + Name of operator (to be used by :func:`pylops.utils.describe.describe`) + + Notes + ----- + The UDCT operator applies the uniform discrete curvelet transform + in forward and adjoint modes from the ``ucurv`` library. + + The ``ucurv`` library uses a udct object to represent all the parameters + of the multidimensional transform. The udct object have to be created with the size + of the data need to be transformed, and the cfg parameter which control the + number of resolution and direction. + """ + def __init__(self, sz, cfg, complex=False, sparse=False, dtype=None): + self.udct = udct(sz, cfg, complex, sparse) + self.shape = (self.udct.len, np.prod(sz)) + self.dtype = np.dtype(dtype) + self.explicit = False + self.rmatvec_count = 0 + self.matvec_count = 0 + + @reshaped + def _matvec(self, x: NDArray) -> NDArray: + img = x.reshape(self.udct.sz) + band = ucurvfwd(img, self.udct) + bvec = bands2vec(band) + return bvec + + @reshaped + def _rmatvec(self, x: NDArray) -> NDArray: + band = vec2bands(x, self.udct) + recon = ucurvinv(band, self.udct) + recon2 = recon.reshape(self.udct.sz) + return recon2 diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py index ecf69a95..744b5aa3 100644 --- a/pylops/utils/deps.py +++ b/pylops/utils/deps.py @@ -3,6 +3,7 @@ "jax_enabled", "devito_enabled", "dtcwt_enabled", + "ucurv_enabled", "numba_enabled", "pyfftw_enabled", "pywt_enabled", @@ -76,7 +77,6 @@ def jax_import(message: Optional[str] = None) -> str: "'pip install jax'; " "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'" ) - return jax_message @@ -114,6 +114,23 @@ def dtcwt_import(message: Optional[str] = None) -> str: return dtcwt_message +def ucurv_import(message: Optional[str] = None) -> str: + if ucurv_enabled: + try: + import ucurv # noqa: F401 + + ucurv_message = None + except Exception as e: + ucurv_message = f"Failed to import ucurv (error:{e})." + else: + ucurv_message = ( + f"UCURV not available. " + f"In order to be able to use " + f'{message} run "pip install ucurv".' + ) + return ucurv_message + + def numba_import(message: Optional[str] = None) -> str: if numba_enabled: try: @@ -238,6 +255,7 @@ def sympy_import(message: Optional[str] = None) -> str: ) devito_enabled = util.find_spec("devito") is not None dtcwt_enabled = util.find_spec("dtcwt") is not None +ucurv_enabled = util.find_spec("ucurv") is not None numba_enabled = util.find_spec("numba") is not None pyfftw_enabled = util.find_spec("pyfftw") is not None pywt_enabled = util.find_spec("pywt") is not None diff --git a/pyproject.toml b/pyproject.toml index 6144f6e2..01b44d01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ advanced = [ "scikit-fmm", "spgl1", "dtcwt", + "ucurv", ] [tool.setuptools.packages.find] diff --git a/pytests/test_udct.py b/pytests/test_udct.py new file mode 100644 index 00000000..4f194b20 --- /dev/null +++ b/pytests/test_udct.py @@ -0,0 +1,63 @@ +import numpy as np +import pytest + +# from pylops.signalprocessing import UDCT + +from ucurv import udct, ucurvfwd, ucurvinv, bands2vec, vec2bands + +eps = 1e-6 +shapes = [ + [[256, 256], ], + [[32, 32, 32], ], + [[16, 16, 16, 16], ] +] + +configurations = [ + [[[3, 3]], + [[6, 6]], + [[12, 12]], + [[12, 12], [24, 24]], + [[12, 12], [3, 3], [6, 6]], + [[12, 12], [3, 3], [6, 6], [24, 24]]], + [[[3, 3, 3]], + [[6, 6, 6]], + [[12, 12, 12]], + [[12, 12, 12], [24, 24, 24]]], + # [[12, 12, 12], [3, 3, 3], [6, 6, 6]], + # [[12, 12, 12], [3, 3, 3], [6, 6, 6], [12, 24, 24]], + + [[[3, 3, 3, 3]]], + # [[6, 6, 6, 6]], + # [[12, 12, 12, 12]], + # [[12, 12, 12, 12], [24, 24, 24, 24]], + # [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6]], + # [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6], [12, 24, 24, 24]], +] + +combinations = [ + (shape, config) + for shape_list, config_list in zip(shapes, configurations) + for shape in shape_list + for config in config_list] + + +@pytest.mark.parametrize("shape, cfg", combinations) +def test_ucurv(shape, cfg): + data = np.random.rand(*shape) + tf = udct(shape, cfg) + band = ucurvfwd(data, tf) + recon = ucurvinv(band, tf) + are_close = np.all(np.isclose(data, recon, atol=eps)) + assert are_close + + +@pytest.mark.parametrize("shape, cfg", combinations) +def test_vectorize(shape, cfg): + data = np.random.rand(*shape) + tf = udct(shape, cfg) + band = ucurvfwd(data, tf) + flat = bands2vec(band) + unflat = vec2bands(flat, tf) + recon = ucurvinv(unflat, tf) + are_close = np.all(np.isclose(data, recon, atol=eps)) + assert are_close diff --git a/requirements-dev.txt b/requirements-dev.txt index 6ce1fb00..6a9d1d22 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -11,6 +11,7 @@ scikit-fmm sympy devito dtcwt +ucurv matplotlib ipython pytest diff --git a/requirements-doc.txt b/requirements-doc.txt index 74fea77d..0b0b381e 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -13,6 +13,7 @@ scikit-fmm sympy devito dtcwt +ucurv matplotlib ipython pytest