Skip to content

Commit d3d8ec0

Browse files
Feat/mtf vs field (#488)
* feat: add draft MTF vs. field analysis type * feat: optimize calculation for speed. Patch in tests for speed.
1 parent ebd2de2 commit d3d8ec0

File tree

6 files changed

+436
-0
lines changed

6 files changed

+436
-0
lines changed

docs/api/api_analysis.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ encircled energy, field curvature, distortion, etc.
1919
analysis.image_simulation
2020
analysis.irradiance
2121
analysis.jones_pupil
22+
analysis.mtf_vs_field
2223
analysis.pupil_aberration
2324
analysis.ray_fan
2425
analysis.rms_vs_field

docs/gallery/opd_psf_mtf.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ This section contains examples of the analysis of the optical path difference (O
1717
wavefront/mtf_geometric
1818
wavefront/mtf_fft
1919
wavefront/mtf_huygens
20+
wavefront/mtf_vs_field
2021
wavefront/through_focus_mtf
2122
wavefront/sampled_mtf
2223
wavefront/zernike_decomposition

docs/gallery/wavefront/mtf_vs_field.ipynb

Lines changed: 96 additions & 0 deletions
Large diffs are not rendered by default.

optiland/analysis/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .pupil_aberration import PupilAberration
1313
from .irradiance import IncoherentIrradiance
1414
from .intensity import RadiantIntensity
15+
from .mtf_vs_field import MTFvsField
1516
from .through_focus_mtf import ThroughFocusMTF
1617
from .through_focus_spot_diagram import ThroughFocusSpotDiagram
1718
from .jones_pupil import JonesPupil

optiland/analysis/mtf_vs_field.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
"""MTF versus Field Analysis
2+
3+
This module enables the calculation of the Modulation Transfer Function (MTF)
4+
versus field coordinate of an optical system.
5+
6+
Kramer Harrison, 2026
7+
"""
8+
9+
from __future__ import annotations
10+
11+
from typing import TYPE_CHECKING
12+
13+
import matplotlib.pyplot as plt
14+
15+
import optiland.backend as be
16+
from optiland.analysis.base import BaseAnalysis
17+
from optiland.mtf.sampled import SampledMTF
18+
19+
if TYPE_CHECKING:
20+
from matplotlib.axes import Axes
21+
from matplotlib.figure import Figure
22+
23+
from optiland.optic import Optic
24+
25+
26+
class MTFvsField(BaseAnalysis):
27+
"""MTF versus Field Coordinate.
28+
29+
This class is used to analyze the Modulation Transfer Function (MTF) versus
30+
the field coordinate of an optical system for specified spatial frequencies.
31+
32+
Args:
33+
optic (Optic): the optical system.
34+
frequencies (list[float]): the spatial frequencies (in cycles/mm) to analyze.
35+
num_fields (int): the number of fields in the Y direction. Default is 32.
36+
wavelengths (str or list): the wavelengths to be analyzed. Default is 'all'.
37+
num_rays (int): the number of rays across the pupil in 1D for the SampledMTF
38+
calculation. Default is 128.
39+
override_limits (bool): If True, bypasses the limit on the number of frequencies
40+
and wavelengths to prevent cluttered plots. Default is False.
41+
42+
"""
43+
44+
MAX_FREQUENCIES = 5
45+
MAX_WAVELENGTHS = 3
46+
47+
def __init__(
48+
self,
49+
optic: Optic,
50+
frequencies: list[float],
51+
num_fields: int = 32,
52+
wavelengths: str | list[float] = "all",
53+
num_rays: int = 128,
54+
override_limits: bool = False,
55+
):
56+
self.frequencies = frequencies
57+
self.num_fields = num_fields
58+
self.num_rays = num_rays
59+
60+
self._check_limits(override_limits, wavelengths, optic)
61+
62+
# Base Analysis will set self.wavelengths
63+
super().__init__(optic, wavelengths)
64+
65+
def _check_limits(self, override_limits: bool, wavelengths, optic):
66+
"""Check to ensure inputs won't produce an overly cluttered plot."""
67+
if override_limits:
68+
return
69+
70+
if len(self.frequencies) > self.MAX_FREQUENCIES:
71+
raise ValueError(
72+
f"Number of frequencies ({len(self.frequencies)}) exceeds the "
73+
f"recommended limit of {self.MAX_FREQUENCIES} for clean plots. "
74+
"Set override_limits=True to bypass this check."
75+
)
76+
77+
from optiland.utils import resolve_wavelengths
78+
79+
resolved_wls = resolve_wavelengths(optic, wavelengths)
80+
num_wl = len(resolved_wls)
81+
82+
if num_wl > self.MAX_WAVELENGTHS:
83+
raise ValueError(
84+
f"Number of wavelengths ({num_wl}) exceeds the recommended "
85+
f"limit of {self.MAX_WAVELENGTHS} for clean plots. "
86+
"Set override_limits=True to bypass this check."
87+
)
88+
89+
def _generate_data(self):
90+
"""Generate the MTF data across fields, wavelengths, and frequencies."""
91+
fields = [(0.0, float(Hy)) for Hy in be.linspace(0.0, 1.0, self.num_fields)]
92+
self._field_coords = be.array(fields)
93+
94+
# Pre-build list of frequencies to calculate at once
95+
freqs_to_calc = []
96+
for freq in self.frequencies:
97+
freqs_to_calc.append((freq, 0.0))
98+
freqs_to_calc.append((0.0, freq))
99+
100+
results = []
101+
for wl in self.wavelengths:
102+
wl_results = [{"tangential": [], "sagittal": []} for _ in self.frequencies]
103+
104+
for field in fields:
105+
sampled_mtf = SampledMTF(
106+
optic=self.optic,
107+
field=field,
108+
wavelength=wl,
109+
num_rays=self.num_rays,
110+
distribution="uniform",
111+
zernike_terms=37,
112+
zernike_type="fringe",
113+
)
114+
115+
mtfs = sampled_mtf.calculate_mtf(freqs_to_calc)
116+
117+
for i_freq in range(len(self.frequencies)):
118+
wl_results[i_freq]["tangential"].append(mtfs[2 * i_freq])
119+
wl_results[i_freq]["sagittal"].append(mtfs[2 * i_freq + 1])
120+
121+
for i_freq in range(len(self.frequencies)):
122+
wl_results[i_freq]["tangential"] = be.array(
123+
wl_results[i_freq]["tangential"]
124+
)
125+
wl_results[i_freq]["sagittal"] = be.array(
126+
wl_results[i_freq]["sagittal"]
127+
)
128+
129+
results.append(wl_results)
130+
131+
return results
132+
133+
def view(
134+
self,
135+
fig_to_plot_on: Figure | None = None,
136+
figsize: tuple[float, float] = (8, 5),
137+
) -> tuple[Figure, Axes]:
138+
"""
139+
Plots the MTF versus the field coordinate for each frequency and wavelength.
140+
141+
Args:
142+
fig_to_plot_on (Figure, optional): An existing matplotlib Figure to
143+
plot on. If provided, the plot will be embedded in this figure.
144+
If None (default), a new figure will be created.
145+
figsize (tuple[float, float], optional): Size of the figure to create
146+
if `fig_to_plot_on` is None. Defaults to (8, 5).
147+
148+
Returns:
149+
tuple[Figure, Axes]: The matplotlib Figure and Axes objects
150+
containing the plot.
151+
"""
152+
is_gui_embedding = fig_to_plot_on is not None
153+
154+
if is_gui_embedding:
155+
current_fig = fig_to_plot_on
156+
current_fig.clear()
157+
ax = current_fig.add_subplot(111)
158+
else:
159+
current_fig, ax = plt.subplots(figsize=figsize)
160+
161+
max_field = float(self.optic.fields.max_field)
162+
y_coords_normalized = be.to_numpy(self._field_coords[:, 1])
163+
x_plot = y_coords_normalized * max_field
164+
165+
# Determine X-axis label
166+
field_def = self.optic.field_definition
167+
x_label = "Field Coordinate"
168+
if field_def is not None:
169+
field_name = field_def.__class__.__name__
170+
if "Angle" in field_name:
171+
x_label = "Angle (deg)"
172+
elif "Height" in field_name:
173+
x_label = "Height (mm)"
174+
else:
175+
# Fallback if no specific type is set but fields exist
176+
x_label = "Field Coordinate"
177+
178+
axes_color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
179+
180+
for i_wl, wavelength in enumerate(self.wavelengths):
181+
for i_freq, freq in enumerate(self.frequencies):
182+
color_idx = (i_wl * len(self.frequencies) + i_freq) % len(
183+
axes_color_cycle
184+
)
185+
color = axes_color_cycle[color_idx]
186+
187+
tan_data = be.to_numpy(self.data[i_wl][i_freq]["tangential"])
188+
sag_data = be.to_numpy(self.data[i_wl][i_freq]["sagittal"])
189+
190+
label_prefix = f"{freq} cyc/mm"
191+
if len(self.wavelengths) > 1:
192+
label_prefix += f", {wavelength:.4f} µm"
193+
194+
ax.plot(
195+
x_plot,
196+
tan_data,
197+
linestyle="-",
198+
color=color,
199+
label=f"{label_prefix} (Tan)",
200+
)
201+
ax.plot(
202+
x_plot,
203+
sag_data,
204+
linestyle="--",
205+
color=color,
206+
label=f"{label_prefix} (Sag)",
207+
)
208+
209+
ax.set_xlabel(x_label)
210+
ax.set_ylabel("Modulus of the OTF")
211+
ax.legend(bbox_to_anchor=(1.05, 0.5), loc="center left")
212+
213+
if max_field > 0:
214+
ax.set_xlim(0, max_field)
215+
216+
ax.set_ylim(0, 1.05)
217+
ax.grid(True, linestyle=":", alpha=0.5)
218+
current_fig.tight_layout()
219+
220+
if is_gui_embedding and hasattr(current_fig, "canvas"):
221+
current_fig.canvas.draw_idle()
222+
223+
return current_fig, ax

tests/test_mtf_vs_field.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import pytest
2+
import matplotlib.pyplot as plt
3+
from matplotlib.axes import Axes
4+
from matplotlib.figure import Figure
5+
6+
from optiland.analysis import MTFvsField
7+
from optiland.samples.objectives import CookeTriplet
8+
import optiland.backend as be
9+
import numpy as np
10+
from unittest.mock import patch
11+
12+
13+
class TestMTFvsField:
14+
"""Test suite for the MTFvsField class."""
15+
16+
@patch('optiland.analysis.MTFvsField._generate_data')
17+
def test_init_defaults(self, mock_generate_data, set_test_backend):
18+
"""Test initialization with default parameters."""
19+
mock_generate_data.return_value = [
20+
[
21+
{"tangential": np.zeros(32), "sagittal": np.zeros(32)},
22+
{"tangential": np.zeros(32), "sagittal": np.zeros(32)}
23+
] for _ in range(3)
24+
]
25+
optic = CookeTriplet()
26+
mtf_vf = MTFvsField(optic, frequencies=[10.0, 20.0])
27+
28+
assert mtf_vf.frequencies == [10.0, 20.0]
29+
assert mtf_vf.num_fields == 32
30+
assert mtf_vf.num_rays == 128
31+
assert mtf_vf.wavelengths == [w.value for w in optic.wavelengths.wavelengths]
32+
33+
# Check results structure
34+
assert len(mtf_vf.data) == len(mtf_vf.wavelengths)
35+
assert len(mtf_vf.data[0]) == len(mtf_vf.frequencies)
36+
37+
# Check tangential and sagittal arrays
38+
res = mtf_vf.data[0][0]
39+
assert "tangential" in res
40+
assert "sagittal" in res
41+
assert len(res["tangential"]) == mtf_vf.num_fields
42+
assert len(res["sagittal"]) == mtf_vf.num_fields
43+
44+
def test_init_limits_frequencies(self, set_test_backend):
45+
"""Test the frequency limit checks."""
46+
optic = CookeTriplet()
47+
freqs = [10, 20, 30, 40, 50, 60] # length 6 > MAX=5
48+
49+
with pytest.raises(ValueError, match="Number of frequencies"):
50+
MTFvsField(optic, frequencies=freqs)
51+
52+
# Try with override
53+
with patch('optiland.analysis.MTFvsField._generate_data'):
54+
mtf_vf = MTFvsField(optic, frequencies=freqs, num_rays=16, num_fields=2, override_limits=True)
55+
assert len(mtf_vf.frequencies) == 6
56+
57+
def test_init_limits_wavelengths(self, set_test_backend):
58+
"""Test the frequency limit checks."""
59+
optic = CookeTriplet()
60+
optic.add_wavelength(0.5, weight=1.0)
61+
optic.add_wavelength(0.6, weight=1.0)
62+
optic.add_wavelength(0.7, weight=1.0)
63+
# Optic now has 4 wavelengths (>3 MAX)
64+
65+
with pytest.raises(ValueError, match="Number of wavelengths"):
66+
MTFvsField(optic, frequencies=[10], wavelengths="all")
67+
68+
# Try with override
69+
with patch('optiland.analysis.MTFvsField._generate_data'):
70+
mtf_vf = MTFvsField(optic, frequencies=[10], wavelengths="all", num_rays=16, num_fields=2, override_limits=True)
71+
assert len(mtf_vf.wavelengths) == 6
72+
73+
def test_custom_params(self, set_test_backend):
74+
"""Test initialization with custom parameters."""
75+
optic = CookeTriplet()
76+
mtf_vf = MTFvsField(
77+
optic,
78+
frequencies=[15.0],
79+
num_fields=3,
80+
wavelengths=[0.55],
81+
num_rays=16
82+
)
83+
84+
assert mtf_vf.frequencies == [15.0]
85+
assert mtf_vf.num_fields == 3
86+
assert mtf_vf.num_rays == 16
87+
assert mtf_vf.wavelengths == [0.55]
88+
89+
# Check data
90+
res = mtf_vf.data[0][0]
91+
assert len(res["tangential"]) == 3
92+
93+
# Check values
94+
assert be.all(res["tangential"] >= 0.0)
95+
assert be.all(res["tangential"] <= 1.0)
96+
assert be.all(res["sagittal"] >= 0.0)
97+
assert be.all(res["sagittal"] <= 1.0)
98+
99+
def test_view(self, set_test_backend):
100+
"""Test view method to ensure it creates plots properly."""
101+
optic = CookeTriplet()
102+
mtf_vf = MTFvsField(optic, frequencies=[10.0], num_fields=2, num_rays=16)
103+
fig, ax = mtf_vf.view()
104+
105+
assert fig is not None
106+
assert ax is not None
107+
assert isinstance(fig, Figure)
108+
assert isinstance(ax, Axes)
109+
110+
# Check label is present based on field definition in CookeTriplet
111+
assert ax.get_xlabel() in ["Angle (deg)", "Height (mm)", "Field Coordinate"]
112+
assert ax.get_ylabel() == "Modulus of the OTF"
113+
114+
plt.close(fig)

0 commit comments

Comments
 (0)