Skip to content

Commit a9ace80

Browse files
committed
Add Angular Spectrum wave propagation module
1 parent fd572a2 commit a9ace80

File tree

7 files changed

+840
-2
lines changed

7 files changed

+840
-2
lines changed

docs/examples/Tutorial_Wave_Propagation_ASM_Development.ipynb

Lines changed: 549 additions & 0 deletions
Large diffs are not rendered by default.

optiland/backend/numpy_backend.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,18 @@ def fftconvolve(
187187
return _fftconvolve(a, b, mode=mode)
188188

189189

190+
def fftfreq(n: int, d: float = 1.0) -> NDArray:
191+
return np.fft.fftfreq(n, d=d)
192+
193+
194+
def fft2(x: ArrayLike) -> NDArray:
195+
return np.fft.fft2(array(x))
196+
197+
198+
def ifft2(x: ArrayLike) -> NDArray:
199+
return np.fft.ifft2(array(x))
200+
201+
190202
def grid_sample(
191203
input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
192204
):
@@ -265,3 +277,32 @@ def polyval(coeffs: ArrayLike, x: ArrayLike) -> NDArray:
265277
x: A number or an array of numbers at which to evaluate.
266278
"""
267279
return np.polyval(coeffs, x)
280+
281+
282+
def meshgrid(*arrays: ArrayLike, indexing: str = "xy"):
283+
return np.meshgrid(*arrays, indexing=indexing)
284+
285+
286+
def pad(
287+
x: NDArray,
288+
pad_width: tuple[tuple[int, int], tuple[int, int]],
289+
mode: Literal["constant"] = "constant",
290+
constant_values: float = 0.0,
291+
) -> NDArray:
292+
(pt, pb), (pl, pr) = pad_width
293+
pads = [(0, 0)] * x.ndim
294+
pads[-2] = (pt, pb)
295+
pads[-1] = (pl, pr)
296+
return np.pad(x, pads, mode=mode, constant_values=constant_values)
297+
298+
299+
def clamp(x: ArrayLike, min: float | None = None, max: float | None = None) -> NDArray:
300+
return np.clip(array(x), a_min=min, a_max=max)
301+
302+
303+
def broadcast_to(x: ArrayLike, shape) -> NDArray:
304+
return np.broadcast_to(array(x), shape)
305+
306+
307+
def zeros(shape) -> NDArray:
308+
return np.zeros(shape, dtype=complex)

optiland/backend/torch_backend.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,10 @@ def cast(x: ArrayLike) -> Tensor:
309309
return x.to(device=get_device(), dtype=get_precision())
310310

311311

312+
def clamp(x: Tensor, min: float | None = None, max: float | None = None) -> Tensor:
313+
return torch.clamp(x, min=min, max=max)
314+
315+
312316
def copy(x: Tensor) -> Tensor:
313317
return x.clone()
314318

@@ -357,8 +361,8 @@ def flip(x: Tensor) -> Tensor:
357361
return torch.flip(x, dims=(0,))
358362

359363

360-
def meshgrid(*arrays: Tensor) -> tuple[Tensor, ...]:
361-
return torch.meshgrid(*arrays, indexing="xy")
364+
def meshgrid(*arrays: Tensor, indexing: str = "xy") -> tuple[Tensor, ...]:
365+
return torch.meshgrid(*arrays, indexing=indexing)
362366

363367

364368
def roll(
@@ -865,6 +869,18 @@ def eye(n: int) -> Tensor:
865869
# --------------------------
866870
# Signal Processing
867871
# --------------------------
872+
def fftfreq(n: int, d: float = 1.0) -> Tensor:
873+
return torch.fft.fftfreq(n, d=d, device=get_device(), dtype=get_precision())
874+
875+
876+
def fft2(x: Tensor) -> Tensor:
877+
return torch.fft.fft2(x)
878+
879+
880+
def ifft2(x: Tensor) -> Tensor:
881+
return torch.fft.ifft2(x)
882+
883+
868884
def fftconvolve(
869885
in1: Tensor, in2: Tensor, mode: Literal["full", "valid", "same"] = "full"
870886
) -> Tensor:
@@ -997,6 +1013,7 @@ def grid_sample(
9971013
"load",
9981014
# Utilities
9991015
"cast",
1016+
"clamp",
10001017
"copy",
10011018
"is_array_like",
10021019
"size",
@@ -1063,4 +1080,8 @@ def grid_sample(
10631080
# Simulation
10641081
"fftconvolve",
10651082
"grid_sample",
1083+
# FFT
1084+
"fftfreq",
1085+
"fft2",
1086+
"ifft2",
10661087
]

optiland/optic/optic.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
SurfaceSagViewer,
4343
)
4444
from optiland.wavelength import WavelengthGroup
45+
from optiland.wavepropagation import BaseWavePropagator
4546

4647
if TYPE_CHECKING:
4748
from matplotlib.axes import Axes
@@ -129,6 +130,7 @@ def _initialize_attributes(self):
129130
self.fields: FieldGroup = FieldGroup()
130131
self.wavelengths: WavelengthGroup = WavelengthGroup()
131132

133+
self.wave_propagator: BaseWavePropagator = BaseWavePropagator(self)
132134
self.paraxial: Paraxial = Paraxial(self)
133135
self.aberrations: Aberrations = Aberrations(self)
134136
self.ray_tracer: RealRayTracer = RealRayTracer(self)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# flake8: noqa
2+
3+
from .base import BaseWavePropagator
4+
from .asm import AngularSpectrumPropagator

optiland/wavepropagation/asm.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import optiland.backend as be
2+
from math import pi
3+
4+
5+
class AngularSpectrumPropagator:
6+
def __init__(
7+
self,
8+
num_points: int,
9+
dx: float,
10+
evanescent: str = "clamp",
11+
):
12+
self.dx = float(dx)
13+
self.evanescent = evanescent
14+
M_orig, N_orig = (num_points, num_points)
15+
self.Mpad = int(round(M_orig * 0.5))
16+
self.Npad = int(round(N_orig * 0.5))
17+
M = M_orig + 2 * self.Mpad
18+
N = N_orig + 2 * self.Npad
19+
fx = be.fftfreq(N, d=self.dx)
20+
fy = be.fftfreq(M, d=self.dx)
21+
FY, FX = be.meshgrid(fy, fx, indexing="ij")
22+
self.fx2_fy2 = (FX * FX + FY * FY)[None, None]
23+
24+
def __call__(self, input_field, distance, wavelengths):
25+
return self.forward(input_field, distance, wavelengths)
26+
27+
def forward(self, input_field, distance, wavelengths):
28+
if getattr(input_field, "ndim", None) != 4:
29+
raise ValueError("input_field must have shape (B, C, M, N).")
30+
31+
B = input_field.shape[0]
32+
distance = be.array(distance).reshape(-1, 1, 1, 1)
33+
if distance.shape[0] == 1 and B > 1:
34+
distance = be.broadcast_to(distance, (B, 1, 1, 1))
35+
wavelengths = be.array(wavelengths).reshape(-1, 1, 1, 1)
36+
if wavelengths.shape[0] == 1 and B > 1:
37+
wavelengths = be.broadcast_to(wavelengths, (B, 1, 1, 1))
38+
39+
k = 2.0 * pi / (wavelengths * 1e-3)
40+
41+
if self.Mpad > 0 or self.Npad > 0:
42+
padded = be.pad(
43+
input_field, ((self.Mpad, self.Mpad), (self.Npad, self.Npad))
44+
)
45+
else:
46+
padded = input_field
47+
48+
spectrum = be.fft2(padded)
49+
argument = k * k - (2.0 * pi) ** 2 * self.fx2_fy2
50+
51+
if self.evanescent == "clamp":
52+
kz = be.sqrt(be.clamp(argument, min=0.0))
53+
H = be.exp(1j * be.to_complex(kz) * be.to_complex(distance))
54+
elif self.evanescent == "decay":
55+
kz = be.sqrt(be.to_complex(argument))
56+
H = be.exp(1j * kz * be.to_complex(distance))
57+
else:
58+
raise ValueError('evanescent must be "clamp" or "decay".')
59+
60+
out = be.ifft2(spectrum * H)
61+
62+
if self.Mpad > 0:
63+
out = out[:, :, self.Mpad : -self.Mpad, :]
64+
if self.Npad > 0:
65+
out = out[:, :, :, self.Npad : -self.Npad]
66+
67+
return out

optiland/wavepropagation/base.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
from __future__ import annotations
2+
3+
import warnings
4+
from math import pi
5+
from typing import TYPE_CHECKING
6+
7+
import optiland.backend as be
8+
from optiland.utils import resolve_wavelengths
9+
10+
if TYPE_CHECKING:
11+
from optiland.optic import Optic
12+
13+
14+
class BaseWavePropagator:
15+
def __init__(self, optic: Optic):
16+
self.optic = optic
17+
18+
def _build_propagator(self, num_points: int, dx: float):
19+
from optiland.wavepropagation.asm import AngularSpectrumPropagator as Propagator
20+
21+
return Propagator(num_points, dx)
22+
23+
def compute_surface_phase(self, surf, X, Y, wl_arr):
24+
interaction = getattr(surf, "interaction_model", None)
25+
if interaction is None:
26+
raise ValueError("Surface has no interaction_model")
27+
28+
interaction_type = getattr(interaction, "interaction_type", None)
29+
if interaction_type not in ("phase", "refractive_reflective"):
30+
raise NotImplementedError(
31+
f"Unsupported interaction_type: {interaction_type}"
32+
)
33+
34+
if interaction_type == "phase":
35+
phase_profile = getattr(interaction, "phase_profile", None)
36+
if phase_profile is None:
37+
raise ValueError("Phase surface has no phase_profile")
38+
phi = be.array(phase_profile.get_phase(X, Y, wl_arr))
39+
else:
40+
n = surf.material_post.n(wl_arr) if hasattr(surf, "material_post") else 1.0
41+
sag = be.array(surf.geometry.sag(X, Y))
42+
phi = (
43+
(2 * pi / (wl_arr * 1e-3))[:, None, None]
44+
* (n - 1.0)[:, None, None]
45+
* sag[None, :, :]
46+
)
47+
48+
if getattr(phi, "ndim", None) == 2:
49+
phi = phi[None]
50+
51+
dphi_x = be.max(be.abs(phi[:, 1:] - phi[:, :-1]))
52+
dphi_y = be.max(be.abs(phi[:, :, 1:] - phi[:, :, :-1]))
53+
dphi_max = be.max(be.array([dphi_x, dphi_y]))
54+
55+
if dphi_max > be.pi:
56+
warnings.warn(
57+
f"Phase under-sampled: phase jump > π between pixels. Max phase jump = {dphi_max:.2f} rad.",
58+
RuntimeWarning,
59+
stacklevel=2,
60+
)
61+
62+
phase = be.exp(1j * phi)
63+
if interaction_type != "phase":
64+
phase = phase.conj()
65+
return phase
66+
67+
def create_input_field(self, X, Y, wl_arr, field, w0: float | str | None = "auto"):
68+
angle_x, angle_y = field
69+
70+
k = (2 * pi / (wl_arr * 1e-3))[:, None, None]
71+
ax = be.array(float(angle_x))
72+
ay = be.array(float(angle_y))
73+
74+
phase = k * (X[None] * be.sin(ax) + Y[None] * be.sin(ay))
75+
field = be.exp(1j * phase)
76+
77+
if w0 is not None:
78+
if w0 == "auto":
79+
dx = float(X[0, 1] - X[0, 0])
80+
w0 = ((X.shape[0] * dx) / 2) / 2
81+
r2 = X**2 + Y**2
82+
field = field * be.exp(-(r2 / (w0**2)))[None]
83+
84+
return field
85+
86+
def compute_field(
87+
self,
88+
z_target: float,
89+
num_points: int,
90+
dx: float,
91+
field: list[tuple[float, float]] | tuple[float, float] | None = None,
92+
wavelengths: str | float | list = "primary",
93+
beam_waist: float | str | None = "auto",
94+
):
95+
if field is None:
96+
fields = [(0.0, 0.0)]
97+
elif isinstance(field, tuple) and len(field) == 2:
98+
fields = [field]
99+
else:
100+
fields = list(field)
101+
102+
F = len(fields)
103+
if F == 0:
104+
raise ValueError("No fields provided.")
105+
106+
wavelengths_resolved = resolve_wavelengths(self.optic, wavelengths)
107+
wl_arr = be.array([float(w) for w in wavelengths_resolved])
108+
W = int(wl_arr.shape[0])
109+
if W == 0:
110+
raise ValueError("No wavelengths resolved.")
111+
112+
x = be.linspace(-(num_points // 2) * dx, (num_points // 2) * dx, num_points)
113+
y = be.copy(x)
114+
Y, X = be.meshgrid(y, x, indexing="ij")
115+
116+
field_arr = be.zeros((F, W, num_points, num_points)) + 0j
117+
for i, f in enumerate(fields):
118+
field_arr[i] = self.create_input_field(
119+
X=X, Y=Y, wl_arr=wl_arr, field=f, w0=beam_waist
120+
)
121+
122+
propagator = self._build_propagator(num_points, dx)
123+
current_z = 0.0
124+
wl_batch = wl_arr.repeat(F)
125+
126+
def propagate(field_arr, dist: float):
127+
flat = field_arr.reshape(F * W, num_points, num_points)
128+
flat = propagator(flat[:, None], dist, wl_batch)[:, 0]
129+
return flat.reshape(F, W, num_points, num_points)
130+
131+
for surf in self.optic.surface_group.surfaces:
132+
phase = self.compute_surface_phase(surf, X, Y, wl_arr)
133+
field_arr = field_arr * phase[None]
134+
135+
if getattr(surf, "aperture", None):
136+
aperture = be.array(surf.aperture.contains(X, Y))
137+
field_arr = field_arr * aperture[None, None]
138+
139+
t = surf.thickness if surf.thickness != float("inf") else 0.0
140+
141+
if current_z + t >= z_target:
142+
remaining = z_target - current_z
143+
if remaining > 0:
144+
field_arr = propagate(field_arr, remaining)
145+
return field_arr
146+
147+
if t > 0:
148+
field_arr = propagate(field_arr, t)
149+
current_z += t
150+
151+
if z_target > current_z:
152+
field_arr = propagate(field_arr, z_target - current_z)
153+
154+
return field_arr

0 commit comments

Comments
 (0)