Skip to content

Commit 1e4e674

Browse files
authored
Feature/mnist (#47)
* Refactor UDCT and related components for improved clarity and functionality - Updated the UDCT class to support both angular_wedges_config and num_scales/wedges_per_direction parameters for enhanced flexibility in configuration. - Introduced MUDCTCoefficients type for monogenic transforms, improving type safety and clarity in the codebase. - Enhanced error messages for invalid angular wedge configurations, providing clearer guidance for users. - Updated documentation to reflect changes in parameter handling and added examples for better usability. - Improved test coverage for UDCT and MUDCT functionalities, ensuring robustness and correctness across various configurations. * Enhance consistency and clarity in NumPy and PyTorch module exports - Added pylint comments to suppress duplicate code warnings in both NumPy and PyTorch __init__.py files, indicating intentional duplication for shared API. - Updated the __all__ list in both modules to ensure consistent API exposure, aligning with the shared functionality between NumPy and PyTorch implementations. - Improved documentation comments to clarify the rationale behind the similar exports, enhancing maintainability and understanding of the code structure. * Update project configuration and add new example for UDCT - Added 'data/' directory to .gitignore to exclude data files from version control. - Updated pyproject.toml to include 'scikit-learn' as a dependency for enhanced functionality. - Enhanced documentation requirements by adding 'torchvision' to docs/requirements.txt for improved image processing capabilities. - Introduced a new example script 'plot_10_udct_mnist.py' demonstrating the use of UDCT for image classification on the MNIST dataset, showcasing the model architecture and training process. * Enhance type checking and update example for UDCT - Updated `pyproject.toml` to include `torchvision` in mypy overrides for improved type checking. - Modified `docs/requirements.txt` to ensure `torchvision` is included for documentation builds. - Refactored `plot_10_udct_mnist.py` to implement a two-layer MLP for classification, improving model architecture clarity and performance. - Updated comments and documentation within the example script for better understanding and maintainability. * Enhance UDCT example and documentation - Updated the `plot_10_udct_mnist.py` script to return both average loss and accuracy during training and testing, improving performance metrics tracking. - Enhanced the visualization section to include separate subplots for training/testing loss and accuracy, providing clearer insights into model performance. - Added `:no-index:` directive in `curvelets.torch` documentation to prevent indexing, streamlining documentation navigation. * Enhance UDCT documentation and refactor code - Added documentation for the `UDCTModule` class in `curvelets.torch`, improving clarity on its usage and functionality. - Refactored the `_UDCTFunction` class to remove the unused `transform_type` parameter, streamlining the API and enhancing code clarity. - Updated the plotting script `plot_10_udct_mnist.py` for better readability by formatting plot commands across multiple lines. * Update training configuration in UDCT MNIST example - Increased the number of training epochs from 2 to 10 in `plot_10_udct_mnist.py` to enhance model performance. - Updated comments and section headers for improved clarity and organization throughout the script. * Remove outdated comments from UDCT MNIST example script - Deleted unnecessary comments and results section in `plot_10_udct_mnist.py` to streamline the code and improve readability. - Focused on enhancing the clarity of the plotting functionality without extraneous information. * Add sphinx_gallery_thumbnail_number comments to example scripts - Introduced `sphinx_gallery_thumbnail_number` comments in multiple example scripts to enhance documentation and improve visual representation in the gallery. - Updated `plot_01_zone_plate.py`, `plot_02_direction_resolution.py`, `plot_03_direction_disk.py`, `plot_04_meyer_wavelet.py`, `plot_05_curvelet_vs_wavelet.py`, and `plot_10_udct_mnist.py` to include these comments for better organization and clarity in the generated documentation. * Implement tensor transfer functionality in UDCTModule and UDCT - Added the `_apply` method in `UDCTModule` to facilitate the application of functions to internal tensors during device transfers. - Introduced `apply_to_tensors` method in `UDCT` for applying functions to all internal tensors, ensuring proper handling during device and dtype changes. - Created unit tests for the `_apply` method in `test_udct_module_apply.py` to verify correct tensor dtype transfers and functionality after conversions. * Update window_threshold default value in UDCTModule from 1e-6 to 1e-5 - Modified the default value of the `window_threshold` parameter in the `UDCTModule` class to improve performance and storage efficiency. - Ensured consistency in parameter documentation to reflect the updated default value.
1 parent c34c846 commit 1e4e674

21 files changed

+1077
-137
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ sg_execution_times.rst
44
playground/
55
.cursor/
66
references
7+
data/
78

89
# Byte-compiled / optimized / DLL files
910
__pycache__/

docs/curvelet_faqs.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ There are three flavors of the discrete curvelet transform with publicly availab
3333
As of 2026, any non-academic use of the CurveLab Toolbox requires a commercial license. Any library which ports or converts Curvelab code to another language is also subject to Curvelab's license.
3434
While this does not include libraries which wrap the CurveLab toolbox and therefore do not contain any source code of Curvelab, their usage still requires Curvelab and therefore its license. Such wrappers include `curvelops <https://github.com/PyLops/curvelops>`_, `PyCurvelab <https://github.com/slimgroup/PyCurvelab>`_, both MIT licensed.
3535

36-
A third flavor is the **Uniform Discrete Curvelet Transform (UDCT)** which does not have the same restrictive license as the FDCT. The UDCT was first implemented in MATLAB (`ucurvmd <https://github.com/nttruong7/ucurvmd>`_ [dead link]) by one of its authors, Truong Nguyen.
36+
A third flavor is the **Uniform Discrete Curvelet Transform (UDCT)** which does not have the same restrictive license as the FDCT. The UDCT was first implemented in MATLAB (`ucurvmd <https://github.com/nttruong7/ucurvmd>`_ [dead link]) by one of its authors, Truong Nguyen.
3737

3838
**This library provides the first open-source, pure-Python implementation of the UDCT**, borrowing heavily from Nguyen's original implementation. The goal of this library is to allow industry professionals to use curvelets more easily. It also goes beyond the original implementation by providing a the support for complex signals, monogenic extension for real signals :cite:`Storath2010`, and a wavelet transform at the highest scale.
3939

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
# This ensures we use the smaller CPU-only build instead of the CUDA build
33
--index-url https://download.pytorch.org/whl/cpu
44
torch
5+
torchvision

docs/source/curvelets.torch.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,8 @@ curvelets.torch package
55
:members:
66
:show-inheritance:
77
:undoc-members:
8+
:no-index:
9+
10+
.. autoclass:: curvelets.torch.UDCTModule
11+
:members:
12+
:show-inheritance:

examples/plot_01_zone_plate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
==========
44
"""
55

6+
# sphinx_gallery_thumbnail_number = 2
7+
68
from __future__ import annotations
79

810
# %%
@@ -43,7 +45,6 @@
4345
ax.set(title="Input")
4446

4547
# %%
46-
# sphinx_gallery_thumbnail_number = 2
4748
fig, ax = plt.subplots(figsize=(4, 3))
4849
im = ax.imshow(zone_plate_inv.T, **opts)
4950
_, cb = create_colorbar(im=im, ax=ax)

examples/plot_02_direction_resolution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
fewer.
88
"""
99

10+
# sphinx_gallery_thumbnail_number = 2
11+
1012
from __future__ import annotations
1113

1214
import matplotlib.pyplot as plt
@@ -211,7 +213,6 @@ def plot_colorbars(
211213
C = C_sym
212214
colored_wins = color_windows(C, cmaps_dir=cmaps, color_low=color_low, color_bg=color_bg)
213215

214-
# sphinx_gallery_thumbnail_number = 2
215216
fig = plt.figure(layout="constrained")
216217
fig.suptitle(title)
217218
gs = GridSpec(3, 2, figure=fig, height_ratios=[8, 1, 1])

examples/plot_03_direction_disk.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
`Kymatio's Scattering disks <https://www.kymat.io/gallery_2d/plot_scattering_disk.html>`__.
77
"""
88

9+
# sphinx_gallery_thumbnail_number = 2
10+
911
from __future__ import annotations
1012

1113
import matplotlib.pyplot as plt
@@ -62,7 +64,6 @@
6264
energy_c = apply_along_wedges(d_c, lambda w, *_: np.sqrt((np.abs(w) ** 2).mean()))
6365

6466
# %%
65-
# sphinx_gallery_thumbnail_number = 2
6667
fig, ax = plt.subplots(figsize=(12, figsize_aspect * 8))
6768
ax.imshow(data.T, vmin=-vmax, vmax=vmax, **opts_space)
6869
overlay_arrows(kvecs, ax, arrowprops={"edgecolor": "w", "facecolor": "k"})

examples/plot_04_meyer_wavelet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
reconstruction of the original signal.
99
"""
1010

11+
# sphinx_gallery_thumbnail_number = 4
12+
1113
from __future__ import annotations
1214

1315
# %%
@@ -135,7 +137,6 @@
135137
ky = fftshift(fftfreq(ny))
136138

137139
# Create figure with 4 subplots
138-
# sphinx_gallery_thumbnail_number = 4
139140
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
140141
axs = axs.flatten()
141142

examples/plot_05_curvelet_vs_wavelet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
showing both frequency-domain windows and spatial coefficients.
88
"""
99

10+
# sphinx_gallery_thumbnail_number = 2
11+
1012
from __future__ import annotations
1113

1214
# %%
@@ -112,7 +114,6 @@
112114
}
113115

114116
# Plot curvelet and wavelet windows
115-
# sphinx_gallery_thumbnail_number = 2
116117
n_curvelet_windows = len(curvelet_windows_scale1)
117118
n_wavelet_windows = len(wavelet_windows_scale1)
118119
n_cols = max(n_curvelet_windows, n_wavelet_windows)

examples/plot_07_monogenic_verification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
# Reproducing Formula
6868
# ###################
6969
#
70-
# According to :cite:`Storath2010`, the monogenic curvelet transform should satisfy
70+
# According to :cite:t:`Storath2010`, the monogenic curvelet transform should satisfy
7171
# the reproducing formula:
7272
#
7373
# .. math::

0 commit comments

Comments
 (0)