Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
549 changes: 549 additions & 0 deletions docs/examples/Tutorial_Wave_Propagation_ASM_Development.ipynb

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions optiland/backend/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,18 @@ def fftconvolve(
return _fftconvolve(a, b, mode=mode)


def fftfreq(n: int, d: float = 1.0) -> NDArray:
return np.fft.fftfreq(n, d=d)


def fft2(x: ArrayLike) -> NDArray:
return np.fft.fft2(array(x))


def ifft2(x: ArrayLike) -> NDArray:
return np.fft.ifft2(array(x))


def grid_sample(
input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
):
Expand Down Expand Up @@ -265,3 +277,32 @@ def polyval(coeffs: ArrayLike, x: ArrayLike) -> NDArray:
x: A number or an array of numbers at which to evaluate.
"""
return np.polyval(coeffs, x)


def meshgrid(*arrays: ArrayLike, indexing: str = "xy"):
return np.meshgrid(*arrays, indexing=indexing)


def pad(
x: NDArray,
pad_width: tuple[tuple[int, int], tuple[int, int]],
mode: Literal["constant"] = "constant",
constant_values: float = 0.0,
) -> NDArray:
(pt, pb), (pl, pr) = pad_width
pads = [(0, 0)] * x.ndim
pads[-2] = (pt, pb)
pads[-1] = (pl, pr)
return np.pad(x, pads, mode=mode, constant_values=constant_values)


def clamp(x: ArrayLike, min: float | None = None, max: float | None = None) -> NDArray:
return np.clip(array(x), a_min=min, a_max=max)


def broadcast_to(x: ArrayLike, shape) -> NDArray:
return np.broadcast_to(array(x), shape)


def zeros(shape) -> NDArray:
return np.zeros(shape, dtype=complex)
25 changes: 23 additions & 2 deletions optiland/backend/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ def cast(x: ArrayLike) -> Tensor:
return x.to(device=get_device(), dtype=get_precision())


def clamp(x: Tensor, min: float | None = None, max: float | None = None) -> Tensor:
return torch.clamp(x, min=min, max=max)


def copy(x: Tensor) -> Tensor:
return x.clone()

Expand Down Expand Up @@ -357,8 +361,8 @@ def flip(x: Tensor) -> Tensor:
return torch.flip(x, dims=(0,))


def meshgrid(*arrays: Tensor) -> tuple[Tensor, ...]:
return torch.meshgrid(*arrays, indexing="xy")
def meshgrid(*arrays: Tensor, indexing: str = "xy") -> tuple[Tensor, ...]:
return torch.meshgrid(*arrays, indexing=indexing)


def roll(
Expand Down Expand Up @@ -865,6 +869,18 @@ def eye(n: int) -> Tensor:
# --------------------------
# Signal Processing
# --------------------------
def fftfreq(n: int, d: float = 1.0) -> Tensor:
return torch.fft.fftfreq(n, d=d, device=get_device(), dtype=get_precision())


def fft2(x: Tensor) -> Tensor:
return torch.fft.fft2(x)


def ifft2(x: Tensor) -> Tensor:
return torch.fft.ifft2(x)


def fftconvolve(
in1: Tensor, in2: Tensor, mode: Literal["full", "valid", "same"] = "full"
) -> Tensor:
Expand Down Expand Up @@ -997,6 +1013,7 @@ def grid_sample(
"load",
# Utilities
"cast",
"clamp",
"copy",
"is_array_like",
"size",
Expand Down Expand Up @@ -1063,4 +1080,8 @@ def grid_sample(
# Simulation
"fftconvolve",
"grid_sample",
# FFT
"fftfreq",
"fft2",
"ifft2",
]
2 changes: 2 additions & 0 deletions optiland/optic/optic.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
SurfaceSagViewer,
)
from optiland.wavelength import WavelengthGroup
from optiland.wavepropagation import BaseWavePropagator

if TYPE_CHECKING:
from matplotlib.axes import Axes
Expand Down Expand Up @@ -129,6 +130,7 @@ def _initialize_attributes(self):
self.fields: FieldGroup = FieldGroup()
self.wavelengths: WavelengthGroup = WavelengthGroup()

self.wave_propagator: BaseWavePropagator = BaseWavePropagator(self)
self.paraxial: Paraxial = Paraxial(self)
self.aberrations: Aberrations = Aberrations(self)
self.ray_tracer: RealRayTracer = RealRayTracer(self)
Expand Down
4 changes: 4 additions & 0 deletions optiland/wavepropagation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# flake8: noqa

from .base import BaseWavePropagator
from .asm import AngularSpectrumPropagator
67 changes: 67 additions & 0 deletions optiland/wavepropagation/asm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import optiland.backend as be
from math import pi


class AngularSpectrumPropagator:
def __init__(
self,
num_points: int,
dx: float,
evanescent: str = "clamp",
):
self.dx = float(dx)
self.evanescent = evanescent
M_orig, N_orig = (num_points, num_points)
self.Mpad = int(round(M_orig * 0.5))
self.Npad = int(round(N_orig * 0.5))
M = M_orig + 2 * self.Mpad
N = N_orig + 2 * self.Npad
fx = be.fftfreq(N, d=self.dx)
fy = be.fftfreq(M, d=self.dx)
FY, FX = be.meshgrid(fy, fx, indexing="ij")
self.fx2_fy2 = (FX * FX + FY * FY)[None, None]

def __call__(self, input_field, distance, wavelengths):
return self.forward(input_field, distance, wavelengths)

def forward(self, input_field, distance, wavelengths):
if getattr(input_field, "ndim", None) != 4:
raise ValueError("input_field must have shape (B, C, M, N).")

B = input_field.shape[0]
distance = be.array(distance).reshape(-1, 1, 1, 1)
if distance.shape[0] == 1 and B > 1:
distance = be.broadcast_to(distance, (B, 1, 1, 1))
wavelengths = be.array(wavelengths).reshape(-1, 1, 1, 1)
if wavelengths.shape[0] == 1 and B > 1:
wavelengths = be.broadcast_to(wavelengths, (B, 1, 1, 1))

k = 2.0 * pi / (wavelengths * 1e-3)

if self.Mpad > 0 or self.Npad > 0:
padded = be.pad(
input_field, ((self.Mpad, self.Mpad), (self.Npad, self.Npad))
)
else:
padded = input_field

spectrum = be.fft2(padded)
argument = k * k - (2.0 * pi) ** 2 * self.fx2_fy2

if self.evanescent == "clamp":
kz = be.sqrt(be.clamp(argument, min=0.0))
H = be.exp(1j * be.to_complex(kz) * be.to_complex(distance))
elif self.evanescent == "decay":
kz = be.sqrt(be.to_complex(argument))
H = be.exp(1j * kz * be.to_complex(distance))
else:
raise ValueError('evanescent must be "clamp" or "decay".')

out = be.ifft2(spectrum * H)

if self.Mpad > 0:
out = out[:, :, self.Mpad : -self.Mpad, :]
if self.Npad > 0:
out = out[:, :, :, self.Npad : -self.Npad]

return out
154 changes: 154 additions & 0 deletions optiland/wavepropagation/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from __future__ import annotations

import warnings
from math import pi
from typing import TYPE_CHECKING

import optiland.backend as be
from optiland.utils import resolve_wavelengths

if TYPE_CHECKING:
from optiland.optic import Optic


class BaseWavePropagator:
def __init__(self, optic: Optic):
self.optic = optic

def _build_propagator(self, num_points: int, dx: float):
from optiland.wavepropagation.asm import AngularSpectrumPropagator as Propagator

return Propagator(num_points, dx)

def compute_surface_phase(self, surf, X, Y, wl_arr):
interaction = getattr(surf, "interaction_model", None)
if interaction is None:
raise ValueError("Surface has no interaction_model")

interaction_type = getattr(interaction, "interaction_type", None)
if interaction_type not in ("phase", "refractive_reflective"):
raise NotImplementedError(
f"Unsupported interaction_type: {interaction_type}"
)

if interaction_type == "phase":
phase_profile = getattr(interaction, "phase_profile", None)
if phase_profile is None:
raise ValueError("Phase surface has no phase_profile")
phi = be.array(phase_profile.get_phase(X, Y, wl_arr))
else:
n = surf.material_post.n(wl_arr) if hasattr(surf, "material_post") else 1.0
sag = be.array(surf.geometry.sag(X, Y))
phi = (
(2 * pi / (wl_arr * 1e-3))[:, None, None]
* (n - 1.0)[:, None, None]
* sag[None, :, :]
)

if getattr(phi, "ndim", None) == 2:
phi = phi[None]

dphi_x = be.max(be.abs(phi[:, 1:] - phi[:, :-1]))
dphi_y = be.max(be.abs(phi[:, :, 1:] - phi[:, :, :-1]))
dphi_max = be.max(be.array([dphi_x, dphi_y]))

if dphi_max > be.pi:
warnings.warn(
f"Phase under-sampled: phase jump > π between pixels. Max phase jump = {dphi_max:.2f} rad.",
RuntimeWarning,
stacklevel=2,
)

phase = be.exp(1j * phi)
if interaction_type != "phase":
phase = phase.conj()
return phase

def create_input_field(self, X, Y, wl_arr, field, w0: float | str | None = "auto"):
angle_x, angle_y = field

k = (2 * pi / (wl_arr * 1e-3))[:, None, None]
ax = be.array(float(angle_x))
ay = be.array(float(angle_y))

phase = k * (X[None] * be.sin(ax) + Y[None] * be.sin(ay))
field = be.exp(1j * phase)

if w0 is not None:
if w0 == "auto":
dx = float(X[0, 1] - X[0, 0])
w0 = ((X.shape[0] * dx) / 2) / 2
r2 = X**2 + Y**2
field = field * be.exp(-(r2 / (w0**2)))[None]

return field

def compute_field(
self,
z_target: float,
num_points: int,
dx: float,
field: list[tuple[float, float]] | tuple[float, float] | None = None,
wavelengths: str | float | list = "primary",
beam_waist: float | str | None = "auto",
):
if field is None:
fields = [(0.0, 0.0)]
elif isinstance(field, tuple) and len(field) == 2:
fields = [field]
else:
fields = list(field)

F = len(fields)
if F == 0:
raise ValueError("No fields provided.")

wavelengths_resolved = resolve_wavelengths(self.optic, wavelengths)
wl_arr = be.array([float(w) for w in wavelengths_resolved])
W = int(wl_arr.shape[0])
if W == 0:
raise ValueError("No wavelengths resolved.")

x = be.linspace(-(num_points // 2) * dx, (num_points // 2) * dx, num_points)
y = be.copy(x)
Y, X = be.meshgrid(y, x, indexing="ij")

field_arr = be.zeros((F, W, num_points, num_points)) + 0j
for i, f in enumerate(fields):
field_arr[i] = self.create_input_field(
X=X, Y=Y, wl_arr=wl_arr, field=f, w0=beam_waist
)

propagator = self._build_propagator(num_points, dx)
current_z = 0.0
wl_batch = wl_arr.repeat(F)

def propagate(field_arr, dist: float):
flat = field_arr.reshape(F * W, num_points, num_points)
flat = propagator(flat[:, None], dist, wl_batch)[:, 0]
return flat.reshape(F, W, num_points, num_points)

for surf in self.optic.surface_group.surfaces:
phase = self.compute_surface_phase(surf, X, Y, wl_arr)
field_arr = field_arr * phase[None]

if getattr(surf, "aperture", None):
aperture = be.array(surf.aperture.contains(X, Y))
field_arr = field_arr * aperture[None, None]

t = surf.thickness if surf.thickness != float("inf") else 0.0

if current_z + t >= z_target:
remaining = z_target - current_z
if remaining > 0:
field_arr = propagate(field_arr, remaining)
return field_arr

if t > 0:
field_arr = propagate(field_arr, t)
current_z += t

if z_target > current_z:
field_arr = propagate(field_arr, z_target - current_z)

return field_arr
Loading