From b99e8f6eaec8c589b351286cb06ddb503ebbdf8b Mon Sep 17 00:00:00 2001 From: Duy Nguyen Date: Wed, 28 Aug 2024 15:55:51 +0100 Subject: [PATCH 1/7] First commit, basic operator and test --- pylops/signalprocessing/__init__.py | 5 +-- pylops/signalprocessing/udct.py | 32 ++++++++++++++ pytests/test_udct.py | 67 +++++++++++++++++++++++++++++ requirements-dev.txt | 6 +-- 4 files changed, 103 insertions(+), 7 deletions(-) create mode 100644 pylops/signalprocessing/udct.py create mode 100644 pytests/test_udct.py diff --git a/pylops/signalprocessing/__init__.py b/pylops/signalprocessing/__init__.py index a8e5ed65..b3fa0881 100755 --- a/pylops/signalprocessing/__init__.py +++ b/pylops/signalprocessing/__init__.py @@ -23,7 +23,6 @@ Shift Fractional Shift operator. DWT One dimensional Wavelet operator. DWT2D Two dimensional Wavelet operator. - DWTND N-dimensional Wavelet operator. DCT Discrete Cosine Transform. DTCWT Dual-Tree Complex Wavelet Transform. Radon2D Two dimensional Radon transform. @@ -35,6 +34,7 @@ Patch2D 2D Patching transform operator. Patch3D 3D Patching transform operator. Fredholm1 Fredholm integral of first kind. + UDCT Uniform Discrete Curvelet Transform """ @@ -62,10 +62,10 @@ from .fredholm1 import * from .dwt import * from .dwt2d import * -from .dwtnd import * from .seislet import * from .dct import * from .dtcwt import * +from .udct import * __all__ = [ @@ -95,7 +95,6 @@ "Fredholm1", "DWT", "DWT2D", - "DWTND", "Seislet", "DCT", "DTCWT", diff --git a/pylops/signalprocessing/udct.py b/pylops/signalprocessing/udct.py new file mode 100644 index 00000000..6ddf4986 --- /dev/null +++ b/pylops/signalprocessing/udct.py @@ -0,0 +1,32 @@ +__all__ = ["UDCT"] + +from typing import Any, NewType, Union + +import numpy as np + +from pylops import LinearOperator +from pylops.utils import deps +from pylops.utils._internal import _value_or_sized_to_tuple +from pylops.utils.decorators import reshaped +from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray +from ucurv import * + +class UDCT(LinearOperator): + def __init__(self, sz, cfg, complex = False, sparse = False, dtype=None): + self.udct = udct(sz, cfg, complex, sparse) + self.shape = (self.udct.len, np.prod(sz)) + self.dtype = np.dtype(dtype) + self.explicit = False + self.rmatvec_count = 0 + self.matvec_count = 0 + def _matvec(self, x): + img = x.reshape(self.udct.sz) + band = ucurvfwd(img, self.udct) + bvec = bands2vec(band) + return bvec + + def _rmatvec(self, x): + band = vec2bands(x, self.udct) + recon = ucurvinv(band, self.udct) + recon2 = recon.reshape(self.udct.sz) + return recon2 \ No newline at end of file diff --git a/pytests/test_udct.py b/pytests/test_udct.py new file mode 100644 index 00000000..c3393fe6 --- /dev/null +++ b/pytests/test_udct.py @@ -0,0 +1,67 @@ +import numpy as np +import pytest + +from pylops.signalprocessing import UDCT + +from ucurv import * + +import numpy as np + +eps = 1e-6 +shapes = [ + [[256, 256], ], + [[32, 32, 32], ], + [[16, 16, 16, 16], ] +] + +configurations = [ + [[[3, 3]], + [[6, 6]], + [[12, 12]], + [[12, 12], [24, 24]], + [[12, 12], [3, 3], [6, 6]], + [[12, 12], [3, 3], [6, 6], [24, 24]], + ], + [[[3, 3, 3]], + [[6, 6, 6]], + [[12, 12, 12]], + [[12, 12, 12], [24, 24, 24]], + # [[12, 12, 12], [3, 3, 3], [6, 6, 6]], + # [[12, 12, 12], [3, 3, 3], [6, 6, 6], [12, 24, 24]], + ], + [[[3, 3, 3, 3]], + # [[6, 6, 6, 6]], + # [[12, 12, 12, 12]], + # [[12, 12, 12, 12], [24, 24, 24, 24]], + # [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6]], + # [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6], [12, 24, 24, 24]], + ], +] + +combinations = [ + (shape, config) + for shape_list, config_list in zip(shapes, configurations) + for shape in shape_list + for config in config_list +] + +@pytest.mark.parametrize("shape, cfg", combinations) +def test_ucurv(shape, cfg): + data = np.random.rand(*shape) + tf = udct(shape, cfg) + band = ucurvfwd(data, tf) + recon = ucurvinv(band, tf) + are_close = np.all(np.isclose(data, recon, atol=eps)) + assert(are_close == True) + +@pytest.mark.parametrize("shape, cfg", combinations) +def test_vectorize(shape, cfg): + data = np.random.rand(*shape) + tf = udct(shape, cfg) + band = ucurvfwd(data, tf) + flat = bands2vec(band) + unflat = vec2bands(flat, tf) + recon = ucurvinv(band, tf) + are_close = np.all(np.isclose(data, recon, atol=eps)) + assert(are_close == True) + diff --git a/requirements-dev.txt b/requirements-dev.txt index 6ce1fb00..cd4b5911 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,8 +1,6 @@ numpy>=1.21.0 -scipy>=1.11.0 ---extra-index-url https://download.pytorch.org/whl/cpu +scipy>=1.4.0 torch>=1.2.0 -jax numba pyfftw PyWavelets @@ -20,7 +18,6 @@ docutils<0.18 Sphinx pydata-sphinx-theme sphinx-gallery -sphinxemoji numpydoc nbsphinx image @@ -30,3 +27,4 @@ isort black flake8 mypy +ucurv From 3666662e8e2c9b86eb845df7029ebc43979d5761 Mon Sep 17 00:00:00 2001 From: Duy Nguyen Date: Sun, 1 Sep 2024 16:34:06 +0100 Subject: [PATCH 2/7] Add ucurv module to various places --- environment-dev-arm.yml | 7 ++--- environment-dev.yml | 6 ++-- pylops/signalprocessing/udct.py | 2 ++ pylops/utils/deps.py | 55 ++++++++++++--------------------- pyproject.toml | 3 +- 5 files changed, 29 insertions(+), 44 deletions(-) diff --git a/environment-dev-arm.yml b/environment-dev-arm.yml index 0081a0ad..5ba127d3 100755 --- a/environment-dev-arm.yml +++ b/environment-dev-arm.yml @@ -8,10 +8,8 @@ dependencies: - python>=3.6.4 - pip - numpy>=1.21.0 - - scipy>=1.11.0 + - scipy>=1.4.0 - pytorch>=1.2.0 - - cpuonly - - jax - pyfftw - pywavelets - sympy @@ -28,6 +26,7 @@ dependencies: - pip: - devito - dtcwt + - ucurv - scikit-fmm - spgl1 - pytest-runner @@ -35,7 +34,7 @@ dependencies: - pydata-sphinx-theme - sphinx-gallery - nbsphinx - - sphinxemoji - image - flake8 - mypy + \ No newline at end of file diff --git a/environment-dev.yml b/environment-dev.yml index eb51c4dc..24f38424 100755 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -8,10 +8,8 @@ dependencies: - python>=3.6.4 - pip - numpy>=1.21.0 - - scipy>=1.11.0 + - scipy>=1.4.0 - pytorch>=1.2.0 - - cpuonly - - jax - pyfftw - pywavelets - sympy @@ -29,6 +27,7 @@ dependencies: - pip: - devito - dtcwt + - ucurv - scikit-fmm - spgl1 - pytest-runner @@ -36,7 +35,6 @@ dependencies: - pydata-sphinx-theme - sphinx-gallery - nbsphinx - - sphinxemoji - image - flake8 - mypy diff --git a/pylops/signalprocessing/udct.py b/pylops/signalprocessing/udct.py index 6ddf4986..8d2ef5b1 100644 --- a/pylops/signalprocessing/udct.py +++ b/pylops/signalprocessing/udct.py @@ -11,6 +11,8 @@ from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray from ucurv import * +ucurv_message = deps.ucurv_import("the ucurv module") + class UDCT(LinearOperator): def __init__(self, sz, cfg, complex = False, sparse = False, dtype=None): self.udct = udct(sz, cfg, complex, sparse) diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py index ecf69a95..3497ce86 100644 --- a/pylops/utils/deps.py +++ b/pylops/utils/deps.py @@ -1,8 +1,8 @@ __all__ = [ "cupy_enabled", - "jax_enabled", "devito_enabled", "dtcwt_enabled", + "ucurv_enabled", "numba_enabled", "pyfftw_enabled", "pywt_enabled", @@ -52,34 +52,6 @@ def cupy_import(message: Optional[str] = None) -> str: return cupy_message -def jax_import(message: Optional[str] = None) -> str: - jax_test = ( - util.find_spec("jax") is not None and int(os.getenv("JAX_PYLOPS", 1)) == 1 - ) - if jax_test: - try: - import_module("jax") # noqa: F401 - - jax_message = None - except (ImportError, ModuleNotFoundError) as e: - jax_message = ( - f"Failed to import jax, Falling back to numpy (error: {e}). " - "Please ensure your environment is set up correctly " - "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'" - ) - print(UserWarning(jax_message)) - else: - jax_message = ( - "Jax package not installed or os.getenv('JAX_PYLOPS') == 0. " - f"In order to be able to use {message} " - "ensure 'os.getenv('JAX_PYLOPS') == 1' and run " - "'pip install jax'; " - "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'" - ) - - return jax_message - - def devito_import(message: Optional[str] = None) -> str: if devito_enabled: try: @@ -113,6 +85,21 @@ def dtcwt_import(message: Optional[str] = None) -> str: ) return dtcwt_message +def ucurv_import(message: Optional[str] = None) -> str: + if ucurv_enabled: + try: + import ucurv # noqa: F401 + + ucurv_message = None + except Exception as e: + ucurv_message = f"Failed to import ucurv (error:{e})." + else: + ucurv_message = ( + f"UCURV not available. " + f"In order to be able to use " + f'{message} run "pip install ucurv".' + ) + return ucurv_message def numba_import(message: Optional[str] = None) -> str: if numba_enabled: @@ -224,20 +211,18 @@ def sympy_import(message: Optional[str] = None) -> str: # Set package availability booleans -# cupy and jax: the package is imported to check everything is working correctly, -# if not the package is disabled. We do this here as these libraries are used as drop-in -# replacement for many numpy and scipy routines when cupy/jax arrays are provided. +# cupy: the package is imported to check everything is working correctly, +# if not the package is disabled. We do this here as this library is used as drop-in +# replacement for many numpy and scipy routines when cupy arrays are provided. # all other libraries: we simply check if the package is available and postpone its import # to check everything is working correctly when a user tries to create an operator that requires # such a package cupy_enabled: bool = ( True if (cupy_import() is None and int(os.getenv("CUPY_PYLOPS", 1)) == 1) else False ) -jax_enabled: bool = ( - True if (jax_import() is None and int(os.getenv("JAX_PYLOPS", 1)) == 1) else False -) devito_enabled = util.find_spec("devito") is not None dtcwt_enabled = util.find_spec("dtcwt") is not None +ucurv_enabled = util.find_spec("ucurv") is not None numba_enabled = util.find_spec("numba") is not None pyfftw_enabled = util.find_spec("pyfftw") is not None pywt_enabled = util.find_spec("pywt") is not None diff --git a/pyproject.toml b/pyproject.toml index 6144f6e2..5e435cc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ classifiers = [ ] dependencies = [ "numpy >= 1.21.0", - "scipy >= 1.11.0", + "scipy >= 1.4.0", ] dynamic = ["version"] @@ -44,6 +44,7 @@ advanced = [ "scikit-fmm", "spgl1", "dtcwt", + "ucurv", ] [tool.setuptools.packages.find] From 934a154448d9e947744c28f977903aeee7524153 Mon Sep 17 00:00:00 2001 From: Duy Nguyen Date: Sun, 1 Sep 2024 21:00:03 +0100 Subject: [PATCH 3/7] existing files changes --- environment-dev-arm.yml | 4 ++++ environment-dev.yml | 4 ++++ pylops/signalprocessing/__init__.py | 2 ++ pylops/utils/deps.py | 17 +++++++++++++++++ pyproject.toml | 1 + requirements-dev.txt | 1 + 6 files changed, 29 insertions(+) diff --git a/environment-dev-arm.yml b/environment-dev-arm.yml index 7cb73753..439a34e0 100755 --- a/environment-dev-arm.yml +++ b/environment-dev-arm.yml @@ -10,6 +10,8 @@ dependencies: - numpy>=1.21.0 - scipy>=1.4.0 - pytorch>=1.2.0 + - cpuonly + - jax - pyfftw - pywavelets - sympy @@ -26,6 +28,7 @@ dependencies: - pip: - devito - dtcwt + - ucurv - scikit-fmm - spgl1 - pytest-runner @@ -36,3 +39,4 @@ dependencies: - image - flake8 - mypy + \ No newline at end of file diff --git a/environment-dev.yml b/environment-dev.yml index 59b2c127..f8161474 100755 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -10,6 +10,8 @@ dependencies: - numpy>=1.21.0 - scipy>=1.4.0 - pytorch>=1.2.0 + - cpuonly + - jax - pyfftw - pywavelets - sympy @@ -27,6 +29,7 @@ dependencies: - pip: - devito - dtcwt + - ucurv - scikit-fmm - spgl1 - pytest-runner @@ -34,6 +37,7 @@ dependencies: - pydata-sphinx-theme - sphinx-gallery - nbsphinx + - sphinxemoji - image - flake8 - mypy diff --git a/pylops/signalprocessing/__init__.py b/pylops/signalprocessing/__init__.py index 7137b586..b3fa0881 100755 --- a/pylops/signalprocessing/__init__.py +++ b/pylops/signalprocessing/__init__.py @@ -34,6 +34,7 @@ Patch2D 2D Patching transform operator. Patch3D 3D Patching transform operator. Fredholm1 Fredholm integral of first kind. + UDCT Uniform Discrete Curvelet Transform """ @@ -64,6 +65,7 @@ from .seislet import * from .dct import * from .dtcwt import * +from .udct import * __all__ = [ diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py index 4b2d21e7..3497ce86 100644 --- a/pylops/utils/deps.py +++ b/pylops/utils/deps.py @@ -2,6 +2,7 @@ "cupy_enabled", "devito_enabled", "dtcwt_enabled", + "ucurv_enabled", "numba_enabled", "pyfftw_enabled", "pywt_enabled", @@ -84,6 +85,21 @@ def dtcwt_import(message: Optional[str] = None) -> str: ) return dtcwt_message +def ucurv_import(message: Optional[str] = None) -> str: + if ucurv_enabled: + try: + import ucurv # noqa: F401 + + ucurv_message = None + except Exception as e: + ucurv_message = f"Failed to import ucurv (error:{e})." + else: + ucurv_message = ( + f"UCURV not available. " + f"In order to be able to use " + f'{message} run "pip install ucurv".' + ) + return ucurv_message def numba_import(message: Optional[str] = None) -> str: if numba_enabled: @@ -206,6 +222,7 @@ def sympy_import(message: Optional[str] = None) -> str: ) devito_enabled = util.find_spec("devito") is not None dtcwt_enabled = util.find_spec("dtcwt") is not None +ucurv_enabled = util.find_spec("ucurv") is not None numba_enabled = util.find_spec("numba") is not None pyfftw_enabled = util.find_spec("pyfftw") is not None pywt_enabled = util.find_spec("pywt") is not None diff --git a/pyproject.toml b/pyproject.toml index 9c338a48..5e435cc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ advanced = [ "scikit-fmm", "spgl1", "dtcwt", + "ucurv", ] [tool.setuptools.packages.find] diff --git a/requirements-dev.txt b/requirements-dev.txt index d86f07f1..cd4b5911 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -27,3 +27,4 @@ isort black flake8 mypy +ucurv From b367ea9f65f5bb568778f280fef4442af208bc43 Mon Sep 17 00:00:00 2001 From: Duy Nguyen Date: Sun, 1 Sep 2024 21:05:46 +0100 Subject: [PATCH 4/7] commit main files --- examples/plot_udct.py | 49 ++++++++++++++++++++++++ pylops/signalprocessing/udct.py | 34 +++++++++++++++++ pytests/test_udct.py | 67 +++++++++++++++++++++++++++++++++ 3 files changed, 150 insertions(+) create mode 100644 examples/plot_udct.py create mode 100644 pylops/signalprocessing/udct.py create mode 100644 pytests/test_udct.py diff --git a/examples/plot_udct.py b/examples/plot_udct.py new file mode 100644 index 00000000..35aac56e --- /dev/null +++ b/examples/plot_udct.py @@ -0,0 +1,49 @@ +""" +Uniform Discrete Curvelet Transform +=================================== +This example shows how to use the :py:class:`pylops.signalprocessing.UDCT` operator to perform the +Uniform Discrete Curvelet Transform on a (multi-dimensional) input array. +""" + + +from ucurv import * +import matplotlib.pyplot as plt +import pylops +plt.close("all") + +if False: + sz = [512, 512] + cfg = [[3, 3], [6,6]] + res = len(cfg) + rsq = zoneplate(sz) + img = rsq - np.mean(rsq) + + transform = udct(sz, cfg, complex = False, high = "curvelet") + + imband = ucurvfwd(img, transform) + plt.figure(figsize = (20, 60)) + print(imband.keys()) + plt.imshow(np.abs(ucurv2d_show(imband, transform))) + # plt.show() + + recon = ucurvinv(imband, transform) + + err = img - recon + print(np.max(np.abs(err))) + plt.figure(figsize = (20, 60)) + plt.imshow(np.real(np.concatenate((img, recon, err), axis = 1))) + + plt.figure() + plt.imshow(np.abs(np.fft.fftshift(np.fft.fftn(err)))) + # plt.show() + +################################################################################ + + +sz = [256, 256] +cfg = [[3,3],[6,6]] +x = np.random.rand(256*256) +y = np.random.rand(262144) +F = pylops.signalprocessing.UDCT(sz,cfg) +print(np.dot(y,F*x)) +print(np.dot(x,F.T*y)) \ No newline at end of file diff --git a/pylops/signalprocessing/udct.py b/pylops/signalprocessing/udct.py new file mode 100644 index 00000000..8d2ef5b1 --- /dev/null +++ b/pylops/signalprocessing/udct.py @@ -0,0 +1,34 @@ +__all__ = ["UDCT"] + +from typing import Any, NewType, Union + +import numpy as np + +from pylops import LinearOperator +from pylops.utils import deps +from pylops.utils._internal import _value_or_sized_to_tuple +from pylops.utils.decorators import reshaped +from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray +from ucurv import * + +ucurv_message = deps.ucurv_import("the ucurv module") + +class UDCT(LinearOperator): + def __init__(self, sz, cfg, complex = False, sparse = False, dtype=None): + self.udct = udct(sz, cfg, complex, sparse) + self.shape = (self.udct.len, np.prod(sz)) + self.dtype = np.dtype(dtype) + self.explicit = False + self.rmatvec_count = 0 + self.matvec_count = 0 + def _matvec(self, x): + img = x.reshape(self.udct.sz) + band = ucurvfwd(img, self.udct) + bvec = bands2vec(band) + return bvec + + def _rmatvec(self, x): + band = vec2bands(x, self.udct) + recon = ucurvinv(band, self.udct) + recon2 = recon.reshape(self.udct.sz) + return recon2 \ No newline at end of file diff --git a/pytests/test_udct.py b/pytests/test_udct.py new file mode 100644 index 00000000..c3393fe6 --- /dev/null +++ b/pytests/test_udct.py @@ -0,0 +1,67 @@ +import numpy as np +import pytest + +from pylops.signalprocessing import UDCT + +from ucurv import * + +import numpy as np + +eps = 1e-6 +shapes = [ + [[256, 256], ], + [[32, 32, 32], ], + [[16, 16, 16, 16], ] +] + +configurations = [ + [[[3, 3]], + [[6, 6]], + [[12, 12]], + [[12, 12], [24, 24]], + [[12, 12], [3, 3], [6, 6]], + [[12, 12], [3, 3], [6, 6], [24, 24]], + ], + [[[3, 3, 3]], + [[6, 6, 6]], + [[12, 12, 12]], + [[12, 12, 12], [24, 24, 24]], + # [[12, 12, 12], [3, 3, 3], [6, 6, 6]], + # [[12, 12, 12], [3, 3, 3], [6, 6, 6], [12, 24, 24]], + ], + [[[3, 3, 3, 3]], + # [[6, 6, 6, 6]], + # [[12, 12, 12, 12]], + # [[12, 12, 12, 12], [24, 24, 24, 24]], + # [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6]], + # [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6], [12, 24, 24, 24]], + ], +] + +combinations = [ + (shape, config) + for shape_list, config_list in zip(shapes, configurations) + for shape in shape_list + for config in config_list +] + +@pytest.mark.parametrize("shape, cfg", combinations) +def test_ucurv(shape, cfg): + data = np.random.rand(*shape) + tf = udct(shape, cfg) + band = ucurvfwd(data, tf) + recon = ucurvinv(band, tf) + are_close = np.all(np.isclose(data, recon, atol=eps)) + assert(are_close == True) + +@pytest.mark.parametrize("shape, cfg", combinations) +def test_vectorize(shape, cfg): + data = np.random.rand(*shape) + tf = udct(shape, cfg) + band = ucurvfwd(data, tf) + flat = bands2vec(band) + unflat = vec2bands(flat, tf) + recon = ucurvinv(band, tf) + are_close = np.all(np.isclose(data, recon, atol=eps)) + assert(are_close == True) + From 8bb63c0c72b1f77d43ad72f4ce25e80d583bc2c2 Mon Sep 17 00:00:00 2001 From: Duy Nguyen Date: Sun, 1 Sep 2024 21:23:03 +0100 Subject: [PATCH 5/7] Manually add back new code from main branch --- environment-dev-arm.yml | 5 +++- environment-dev.yml | 5 +++- pylops/signalprocessing/__init__.py | 3 +++ pylops/utils/deps.py | 36 ++++++++++++++++++++++++++--- pyproject.toml | 2 +- requirements-dev.txt | 5 +++- 6 files changed, 49 insertions(+), 7 deletions(-) diff --git a/environment-dev-arm.yml b/environment-dev-arm.yml index 5ba127d3..c711fe76 100755 --- a/environment-dev-arm.yml +++ b/environment-dev-arm.yml @@ -8,8 +8,10 @@ dependencies: - python>=3.6.4 - pip - numpy>=1.21.0 - - scipy>=1.4.0 + - scipy>=1.11.0 - pytorch>=1.2.0 + - cpuonly + - jax - pyfftw - pywavelets - sympy @@ -34,6 +36,7 @@ dependencies: - pydata-sphinx-theme - sphinx-gallery - nbsphinx + - sphinxemoji - image - flake8 - mypy diff --git a/environment-dev.yml b/environment-dev.yml index 24f38424..135319f7 100755 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -8,8 +8,10 @@ dependencies: - python>=3.6.4 - pip - numpy>=1.21.0 - - scipy>=1.4.0 + - scipy>=1.11.0 - pytorch>=1.2.0 + - cpuonly + - jax - pyfftw - pywavelets - sympy @@ -35,6 +37,7 @@ dependencies: - pydata-sphinx-theme - sphinx-gallery - nbsphinx + - sphinxemoji - image - flake8 - mypy diff --git a/pylops/signalprocessing/__init__.py b/pylops/signalprocessing/__init__.py index b3fa0881..b4f8fd7b 100755 --- a/pylops/signalprocessing/__init__.py +++ b/pylops/signalprocessing/__init__.py @@ -23,6 +23,7 @@ Shift Fractional Shift operator. DWT One dimensional Wavelet operator. DWT2D Two dimensional Wavelet operator. + DWTND N-dimensional Wavelet operator. DCT Discrete Cosine Transform. DTCWT Dual-Tree Complex Wavelet Transform. Radon2D Two dimensional Radon transform. @@ -62,6 +63,7 @@ from .fredholm1 import * from .dwt import * from .dwt2d import * +from .dwtnd import * from .seislet import * from .dct import * from .dtcwt import * @@ -95,6 +97,7 @@ "Fredholm1", "DWT", "DWT2D", + "DWTND", "Seislet", "DCT", "DTCWT", diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py index 3497ce86..b320028b 100644 --- a/pylops/utils/deps.py +++ b/pylops/utils/deps.py @@ -1,5 +1,6 @@ __all__ = [ "cupy_enabled", + "jax_enabled", "devito_enabled", "dtcwt_enabled", "ucurv_enabled", @@ -51,6 +52,32 @@ def cupy_import(message: Optional[str] = None) -> str: return cupy_message +def jax_import(message: Optional[str] = None) -> str: + jax_test = ( + util.find_spec("jax") is not None and int(os.getenv("JAX_PYLOPS", 1)) == 1 + ) + if jax_test: + try: + import_module("jax") # noqa: F401 + + jax_message = None + except (ImportError, ModuleNotFoundError) as e: + jax_message = ( + f"Failed to import jax, Falling back to numpy (error: {e}). " + "Please ensure your environment is set up correctly " + "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'" + ) + print(UserWarning(jax_message)) + else: + jax_message = ( + "Jax package not installed or os.getenv('JAX_PYLOPS') == 0. " + f"In order to be able to use {message} " + "ensure 'os.getenv('JAX_PYLOPS') == 1' and run " + "'pip install jax'; " + "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'" + ) + + return jax_message def devito_import(message: Optional[str] = None) -> str: if devito_enabled: @@ -211,15 +238,18 @@ def sympy_import(message: Optional[str] = None) -> str: # Set package availability booleans -# cupy: the package is imported to check everything is working correctly, -# if not the package is disabled. We do this here as this library is used as drop-in -# replacement for many numpy and scipy routines when cupy arrays are provided. +# cupy and jax: the package is imported to check everything is working correctly, +# if not the package is disabled. We do this here as these libraries are used as drop-in +# replacement for many numpy and scipy routines when cupy/jax arrays are provided. # all other libraries: we simply check if the package is available and postpone its import # to check everything is working correctly when a user tries to create an operator that requires # such a package cupy_enabled: bool = ( True if (cupy_import() is None and int(os.getenv("CUPY_PYLOPS", 1)) == 1) else False ) +jax_enabled: bool = ( + True if (jax_import() is None and int(os.getenv("JAX_PYLOPS", 1)) == 1) else False +) devito_enabled = util.find_spec("devito") is not None dtcwt_enabled = util.find_spec("dtcwt") is not None ucurv_enabled = util.find_spec("ucurv") is not None diff --git a/pyproject.toml b/pyproject.toml index 5e435cc7..01b44d01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ classifiers = [ ] dependencies = [ "numpy >= 1.21.0", - "scipy >= 1.4.0", + "scipy >= 1.11.0", ] dynamic = ["version"] diff --git a/requirements-dev.txt b/requirements-dev.txt index cd4b5911..42477b9a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,8 @@ numpy>=1.21.0 -scipy>=1.4.0 +scipy>=1.11.0 +--extra-index-url https://download.pytorch.org/whl/cpu torch>=1.2.0 +jax numba pyfftw PyWavelets @@ -18,6 +20,7 @@ docutils<0.18 Sphinx pydata-sphinx-theme sphinx-gallery +sphinxemoji numpydoc nbsphinx image From f3a792764076366cde427636ca1bc3f9ce7054e2 Mon Sep 17 00:00:00 2001 From: Duy Nguyen Date: Tue, 3 Sep 2024 16:43:42 +0100 Subject: [PATCH 6/7] flake8 correction --- examples/plot_udct.py | 53 ++++++++++++++++----------------- pylops/signalprocessing/udct.py | 51 ++++++++++++++++++++++++------- pylops/utils/deps.py | 9 ++++-- pytests/test_udct.py | 42 ++++++++++++-------------- 4 files changed, 91 insertions(+), 64 deletions(-) diff --git a/examples/plot_udct.py b/examples/plot_udct.py index 35aac56e..cf92035a 100644 --- a/examples/plot_udct.py +++ b/examples/plot_udct.py @@ -5,45 +5,44 @@ Uniform Discrete Curvelet Transform on a (multi-dimensional) input array. """ - -from ucurv import * +import numpy as np +from ucurv import udct, zoneplate, ucurvfwd, ucurvinv, ucurv2d_show import matplotlib.pyplot as plt import pylops plt.close("all") -if False: - sz = [512, 512] - cfg = [[3, 3], [6,6]] - res = len(cfg) - rsq = zoneplate(sz) - img = rsq - np.mean(rsq) +sz = [512, 512] +cfg = [[3, 3], [6, 6]] +res = len(cfg) +rsq = zoneplate(sz) +img = rsq - np.mean(rsq) - transform = udct(sz, cfg, complex = False, high = "curvelet") +transform = udct(sz, cfg, complex=False, high="curvelet") - imband = ucurvfwd(img, transform) - plt.figure(figsize = (20, 60)) - print(imband.keys()) - plt.imshow(np.abs(ucurv2d_show(imband, transform))) - # plt.show() +imband = ucurvfwd(img, transform) +plt.figure(figsize=(20, 60)) +print(imband.keys()) +plt.imshow(np.abs(ucurv2d_show(imband, transform))) +# plt.show() - recon = ucurvinv(imband, transform) +recon = ucurvinv(imband, transform) - err = img - recon - print(np.max(np.abs(err))) - plt.figure(figsize = (20, 60)) - plt.imshow(np.real(np.concatenate((img, recon, err), axis = 1))) +err = img - recon +print(np.max(np.abs(err))) +plt.figure(figsize=(20, 60)) +plt.imshow(np.real(np.concatenate((img, recon, err), axis=1))) - plt.figure() - plt.imshow(np.abs(np.fft.fftshift(np.fft.fftn(err)))) - # plt.show() +plt.figure() +plt.imshow(np.abs(np.fft.fftshift(np.fft.fftn(err)))) +# plt.show() ################################################################################ sz = [256, 256] -cfg = [[3,3],[6,6]] -x = np.random.rand(256*256) +cfg = [[3, 3], [6, 6]] +x = np.random.rand(256 * 256) y = np.random.rand(262144) -F = pylops.signalprocessing.UDCT(sz,cfg) -print(np.dot(y,F*x)) -print(np.dot(x,F.T*y)) \ No newline at end of file +F = pylops.signalprocessing.UDCT(sz, cfg) +print(np.dot(y, F * x)) +print(np.dot(x, F.T * y)) diff --git a/pylops/signalprocessing/udct.py b/pylops/signalprocessing/udct.py index 8d2ef5b1..ccafb5db 100644 --- a/pylops/signalprocessing/udct.py +++ b/pylops/signalprocessing/udct.py @@ -1,34 +1,63 @@ __all__ = ["UDCT"] -from typing import Any, NewType, Union - import numpy as np from pylops import LinearOperator from pylops.utils import deps -from pylops.utils._internal import _value_or_sized_to_tuple from pylops.utils.decorators import reshaped -from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray -from ucurv import * +from pylops.utils.typing import NDArray +from ucurv import udct, ucurvfwd, ucurvinv, bands2vec, vec2bands ucurv_message = deps.ucurv_import("the ucurv module") + class UDCT(LinearOperator): - def __init__(self, sz, cfg, complex = False, sparse = False, dtype=None): + r"""Uniform Discrete Curvelet Transform + + Perform the multidimensional discrete curvelet transforms + + The UDCT operator is a wraparound of the ucurvfwd and ucurvinv + calls in the UCURV package. Refer to + https://ucurv.readthedocs.io for a detailed description of the + input parameters. + + Parameters + ---------- + udct : :obj:`DTypeLike`, optional + Type of elements in input array. + dtype : :obj:`DTypeLike`, optional + Type of elements in input array. + name : :obj:`str`, optional + Name of operator (to be used by :func:`pylops.utils.describe.describe`) + + Notes + ----- + The UDCT operator applies the uniform discrete curvelet transform + in forward and adjoint modes from the ``ucurv`` library. + + The ``ucurv`` library uses a udct object to represent all the parameters + of the multidimensional transform. The udct object have to be created with the size + of the data need to be transformed, and the cfg parameter which control the + number of resolution and direction. + """ + def __init__(self, sz, cfg, complex=False, sparse=False, dtype=None): self.udct = udct(sz, cfg, complex, sparse) self.shape = (self.udct.len, np.prod(sz)) self.dtype = np.dtype(dtype) - self.explicit = False + self.explicit = False self.rmatvec_count = 0 self.matvec_count = 0 - def _matvec(self, x): + + @reshaped + def _matvec(self, x: NDArray) -> NDArray: img = x.reshape(self.udct.sz) band = ucurvfwd(img, self.udct) - bvec = bands2vec(band) + bvec = bands2vec(band) return bvec - def _rmatvec(self, x): + @reshaped + def _rmatvec(self, x: NDArray) -> NDArray: band = vec2bands(x, self.udct) recon = ucurvinv(band, self.udct) recon2 = recon.reshape(self.udct.sz) - return recon2 \ No newline at end of file + return recon2 diff --git a/pylops/utils/deps.py b/pylops/utils/deps.py index b320028b..744b5aa3 100644 --- a/pylops/utils/deps.py +++ b/pylops/utils/deps.py @@ -1,9 +1,9 @@ __all__ = [ "cupy_enabled", - "jax_enabled", + "jax_enabled", "devito_enabled", "dtcwt_enabled", - "ucurv_enabled", + "ucurv_enabled", "numba_enabled", "pyfftw_enabled", "pywt_enabled", @@ -52,6 +52,7 @@ def cupy_import(message: Optional[str] = None) -> str: return cupy_message + def jax_import(message: Optional[str] = None) -> str: jax_test = ( util.find_spec("jax") is not None and int(os.getenv("JAX_PYLOPS", 1)) == 1 @@ -76,9 +77,9 @@ def jax_import(message: Optional[str] = None) -> str: "'pip install jax'; " "for more details visit 'https://jax.readthedocs.io/en/latest/installation.html'" ) - return jax_message + def devito_import(message: Optional[str] = None) -> str: if devito_enabled: try: @@ -112,6 +113,7 @@ def dtcwt_import(message: Optional[str] = None) -> str: ) return dtcwt_message + def ucurv_import(message: Optional[str] = None) -> str: if ucurv_enabled: try: @@ -128,6 +130,7 @@ def ucurv_import(message: Optional[str] = None) -> str: ) return ucurv_message + def numba_import(message: Optional[str] = None) -> str: if numba_enabled: try: diff --git a/pytests/test_udct.py b/pytests/test_udct.py index c3393fe6..4f194b20 100644 --- a/pytests/test_udct.py +++ b/pytests/test_udct.py @@ -1,11 +1,9 @@ import numpy as np import pytest -from pylops.signalprocessing import UDCT +# from pylops.signalprocessing import UDCT -from ucurv import * - -import numpy as np +from ucurv import udct, ucurvfwd, ucurvinv, bands2vec, vec2bands eps = 1e-6 shapes = [ @@ -15,35 +13,33 @@ ] configurations = [ - [[[3, 3]], + [[[3, 3]], [[6, 6]], [[12, 12]], [[12, 12], [24, 24]], [[12, 12], [3, 3], [6, 6]], - [[12, 12], [3, 3], [6, 6], [24, 24]], - ], - [[[3, 3, 3]], + [[12, 12], [3, 3], [6, 6], [24, 24]]], + [[[3, 3, 3]], [[6, 6, 6]], [[12, 12, 12]], - [[12, 12, 12], [24, 24, 24]], + [[12, 12, 12], [24, 24, 24]]], # [[12, 12, 12], [3, 3, 3], [6, 6, 6]], - # [[12, 12, 12], [3, 3, 3], [6, 6, 6], [12, 24, 24]], - ], - [[[3, 3, 3, 3]], + # [[12, 12, 12], [3, 3, 3], [6, 6, 6], [12, 24, 24]], + + [[[3, 3, 3, 3]]], # [[6, 6, 6, 6]], # [[12, 12, 12, 12]], # [[12, 12, 12, 12], [24, 24, 24, 24]], # [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6]], - # [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6], [12, 24, 24, 24]], - ], + # [[12, 12, 12, 12], [3, 3, 3, 3], [6, 6, 6, 6], [12, 24, 24, 24]], ] combinations = [ - (shape, config) - for shape_list, config_list in zip(shapes, configurations) - for shape in shape_list - for config in config_list -] + (shape, config) + for shape_list, config_list in zip(shapes, configurations) + for shape in shape_list + for config in config_list] + @pytest.mark.parametrize("shape, cfg", combinations) def test_ucurv(shape, cfg): @@ -52,7 +48,8 @@ def test_ucurv(shape, cfg): band = ucurvfwd(data, tf) recon = ucurvinv(band, tf) are_close = np.all(np.isclose(data, recon, atol=eps)) - assert(are_close == True) + assert are_close + @pytest.mark.parametrize("shape, cfg", combinations) def test_vectorize(shape, cfg): @@ -61,7 +58,6 @@ def test_vectorize(shape, cfg): band = ucurvfwd(data, tf) flat = bands2vec(band) unflat = vec2bands(flat, tf) - recon = ucurvinv(band, tf) + recon = ucurvinv(unflat, tf) are_close = np.all(np.isclose(data, recon, atol=eps)) - assert(are_close == True) - + assert are_close From 79b91287055dbd0608b35040f20a1ba5f0dfe7bb Mon Sep 17 00:00:00 2001 From: mrava87 Date: Fri, 6 Sep 2024 21:11:35 +0300 Subject: [PATCH 7/7] doc: added ucurv to requirements-doc.txt --- requirements-dev.txt | 2 +- requirements-doc.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 42477b9a..6a9d1d22 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -11,6 +11,7 @@ scikit-fmm sympy devito dtcwt +ucurv matplotlib ipython pytest @@ -30,4 +31,3 @@ isort black flake8 mypy -ucurv diff --git a/requirements-doc.txt b/requirements-doc.txt index 74fea77d..0b0b381e 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -13,6 +13,7 @@ scikit-fmm sympy devito dtcwt +ucurv matplotlib ipython pytest