Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions optiland/materials/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,23 @@ def _create_cache_key(self, wavelength: float | be.ndarray, **kwargs) -> tuple:
wavelength_key = wavelength
return (wavelength_key,) + tuple(sorted(kwargs.items()))

@staticmethod
def _requires_grad(value) -> bool:
"""Check if a value is a torch tensor that requires gradient."""
return hasattr(value, "requires_grad") and value.requires_grad

@staticmethod
def _detach_if_tensor(value):
"""Detach a torch tensor to sever the computation graph link.

This prevents the 'backward through the graph a second time' error
that occurs when a cached tensor still references a freed computation
graph.
"""
if hasattr(value, "detach"):
return value.detach()
return value

def n(self, wavelength: float | be.ndarray, **kwargs) -> float | be.ndarray:
"""Calculates the refractive index at a given wavelength with caching.

Expand All @@ -95,8 +112,17 @@ def n(self, wavelength: float | be.ndarray, **kwargs) -> float | be.ndarray:
return self._n_cache[cache_key]

result = self._calculate_n(wavelength, **kwargs)
self._n_cache[cache_key] = result
return result

# If the result requires grad, it is connected to an optimization
# variable (e.g. the index itself is being optimized). In that case
# we must NOT cache — every forward pass needs a fresh graph.
if self._requires_grad(result):
return result

# Otherwise the value is a constant w.r.t. optimization variables.
# Detach before caching to avoid holding a stale computation graph.
self._n_cache[cache_key] = self._detach_if_tensor(result)
return self._n_cache[cache_key]

def k(self, wavelength: float | be.ndarray, **kwargs) -> float | be.ndarray:
"""Calculates the extinction coefficient at a given wavelength with caching.
Expand All @@ -115,8 +141,12 @@ def k(self, wavelength: float | be.ndarray, **kwargs) -> float | be.ndarray:
return self._k_cache[cache_key]

result = self._calculate_k(wavelength, **kwargs)
self._k_cache[cache_key] = result
return result
# Same logic as n(): skip cache if result is differentiable.
if self._requires_grad(result):
return result

self._k_cache[cache_key] = self._detach_if_tensor(result)
return self._k_cache[cache_key]

@abstractmethod
def _calculate_n(
Expand Down
8 changes: 5 additions & 3 deletions optiland/optimization/optimizer/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def __init__(self, problem: OptimizationProblem):
be.grad_mode.enable()

# Initialize parameters as torch.nn.Parameter objects
initial_params = [var.variable.get_value() for var in self.problem.variables]
# Use var.value (scaled) to match the scaled bounds from var.bounds
initial_params = [var.value for var in self.problem.variables]
self.params = [torch.nn.Parameter(be.array(p)) for p in initial_params]

@abstractmethod
Expand Down Expand Up @@ -118,8 +119,9 @@ def optimize(
optimizer.zero_grad()

# 1. Update the model state from the current nn.Parameter values.
# Use var.update() which inverse-scales from scaled space
for k, param in enumerate(self.params):
self.problem.variables[k].variable.update_value(param)
self.problem.variables[k].update(param)

# 2. Update any dependent properties.
self.problem.update_optics()
Expand Down Expand Up @@ -147,7 +149,7 @@ def optimize(

# Final update to ensure the model reflects the last optimized state
for k, param in enumerate(self.params):
self.problem.variables[k].variable.update_value(param)
self.problem.variables[k].update(param)
self.problem.update_optics()

final_loss = self.problem.sum_squared().item()
Expand Down
10 changes: 10 additions & 0 deletions tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,6 +1754,16 @@ def test_view_with_all_zero_intensities(self, tf_spot):


def read_zmx_file(file_path, skip_lines, cols=(0, 1)):
import os
if not os.path.exists(file_path):
this_dir = os.path.dirname(os.path.abspath(__file__))
if file_path.startswith("tests/") or file_path.startswith("tests\\"):
candidate = os.path.join(this_dir, file_path[6:])
else:
candidate = os.path.join(this_dir, file_path)
if os.path.exists(candidate):
file_path = candidate

try:
data = np.loadtxt(
file_path, skiprows=skip_lines, usecols=cols, encoding="utf-16"
Expand Down
164 changes: 164 additions & 0 deletions tests/test_materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,170 @@ def _calculate_k(self, wavelength, **kwargs):
# and a cache hit
material.n(wavelength_torch_2, temperature=25)
assert material._calculate_n.call_count == 3
def test_detach_if_tensor_plain_value(self, set_test_backend):
    """Plain (non-tensor) values pass through _detach_if_tensor untouched."""
    result = BaseMaterial._detach_if_tensor(1.5)
    assert result == 1.5
    assert BaseMaterial._detach_if_tensor(None) is None

def test_detach_if_tensor_numpy(self, set_test_backend):
    """Numpy arrays (which have no ``detach``) come back unchanged."""
    data = np.array([1.5, 1.6])
    passed_through = BaseMaterial._detach_if_tensor(data)
    np.testing.assert_array_equal(passed_through, data)

def test_requires_grad_plain_value(self, set_test_backend):
    """_requires_grad reports False for anything that is not a grad tensor."""
    for value in (1.5, np.array(1.5)):
        assert BaseMaterial._requires_grad(value) is False


@pytest.mark.skipif(
    "torch" not in be.list_available_backends(),
    reason="PyTorch not installed",
)
class TestBaseMaterialTorchCaching:
    """Tests for material caching behavior specific to torch backend.

    These tests verify:
    - Cached tensors are detached (no stale computation graph)
    - Repeated forward+backward passes don't raise RuntimeError
    - Cache is bypassed when the result requires grad (optimizable index)
    """

    @pytest.fixture(autouse=True)
    def _setup_torch(self):
        # Switch the process-global backend to torch with autograd enabled
        # for every test in this class; restore numpy on teardown so later
        # tests are unaffected. Statement order matters: grad mode and
        # precision must be set after the backend switch.
        be.set_backend("torch")
        be.set_device("cpu")
        be.grad_mode.enable()
        be.set_precision("float64")
        yield
        be.set_backend("numpy")

    def test_detach_if_tensor_torch(self):
        """_detach_if_tensor detaches torch tensors."""
        import torch

        # Multiply so the tensor is a non-leaf with a grad_fn — the exact
        # situation that would pin a computation graph if cached undetached.
        t = torch.tensor(1.5, requires_grad=True) * 2.0  # has grad_fn
        result = BaseMaterial._detach_if_tensor(t)
        assert not result.requires_grad
        assert result.grad_fn is None

    def test_requires_grad_torch(self):
        """_requires_grad correctly detects torch tensors with grad."""
        import torch

        leaf = torch.tensor(1.5, requires_grad=True)
        assert BaseMaterial._requires_grad(leaf) is True

        no_grad = torch.tensor(1.5)
        assert BaseMaterial._requires_grad(no_grad) is False

    def test_cached_value_is_detached_when_no_grad(self):
        """Cached n/k values must be detached when they don't require grad.

        When grad_mode is off, Sellmeier formula results are torch tensors
        without requires_grad. These should be cached and detached to prevent
        stale computation graph references if grad_mode is later re-enabled.
        """
        import torch

        be.grad_mode.disable()
        try:
            mat = materials.Material("N-BK7")
            result = mat.n(0.5876)

            # Should be cached since result doesn't require grad
            assert len(mat._n_cache) > 0, "Result should be cached"

            cached = list(mat._n_cache.values())[0]
            # Guard on type: some backends may cache plain floats here.
            if isinstance(cached, torch.Tensor):
                assert cached.grad_fn is None, (
                    "Cached tensor should be detached (no grad_fn)"
                )
        finally:
            # Re-enable even on failure so sibling tests see grad mode on.
            be.grad_mode.enable()

    def test_cache_bypassed_for_real_material_under_grad_mode(self):
        """Under torch + grad_mode, real material results require_grad=True
        because be.array() creates grad-enabled tensors. The cache must be
        bypassed in this case.
        """
        mat = materials.Material("N-BK7")
        mat.n(0.5876)

        assert len(mat._n_cache) == 0, (
            "Cache should be empty — Sellmeier result requires_grad under "
            "grad_mode, so caching must be skipped"
        )

    def test_no_runtime_error_on_repeated_backward(self):
        """Repeated forward+backward passes must not raise RuntimeError.

        This is the regression test for the 'backward through the graph a
        second time' error. We simulate a multi-step optimization by calling
        material.n() repeatedly and computing a loss that flows through it.
        """
        import torch

        mat = materials.Material("N-BK7")

        # Simulate multiple optimizer steps
        param = torch.nn.Parameter(torch.tensor(0.55, dtype=torch.float64))

        for _ in range(3):
            # Forward: compute n at a wavelength that depends on param
            n_val = mat.n(param)
            loss = (n_val - 1.5) ** 2
            loss.backward()

            # Manual SGD step; mutating .data keeps param a leaf tensor.
            with torch.no_grad():
                param.data -= 0.01 * param.grad
                param.grad.zero_()

    def test_cache_bypassed_when_result_requires_grad(self):
        """n() must NOT cache when the result requires_grad.

        If the refractive index itself is an optimization variable (e.g.
        IdealMaterial with a nn.Parameter index), caching would break
        gradient flow — grad(loss, n) would always be zero.
        """
        import torch

        # Create an IdealMaterial whose index is a nn.Parameter
        mat = materials.IdealMaterial(n=1.5)
        mat.index = torch.nn.Parameter(torch.tensor([1.5], dtype=torch.float64))

        # First call — result requires grad, should NOT be cached
        result1 = mat.n(0.55)
        assert result1.requires_grad, "Result should require grad"
        assert len(mat._n_cache) == 0, "Should not cache requires_grad result"

        # Second call — should recompute (not return stale cached value)
        result2 = mat.n(0.55)
        assert result2.requires_grad, "Second result should also require grad"
        assert len(mat._n_cache) == 0, "Cache should still be empty"

        # Verify gradient actually flows through
        loss = result2.sum()
        loss.backward()
        assert mat.index.grad is not None, "Gradient should flow to index"
        assert mat.index.grad.abs().sum() > 0, "Gradient should be non-zero"

    def test_cache_bypassed_for_k_when_requires_grad(self):
        """k() has the same bypass logic as n()."""
        import torch

        # NOTE(review): assumes IdealMaterial reads self.absorp for k —
        # confirm against the IdealMaterial implementation.
        mat = materials.IdealMaterial(n=1.5, k=0.1)
        mat.absorp = torch.nn.Parameter(torch.tensor([0.1], dtype=torch.float64))

        result = mat.k(0.55)
        assert result.requires_grad
        assert len(mat._k_cache) == 0

        loss = result.sum()
        loss.backward()
        assert mat.absorp.grad is not None



def build_model(material: BaseMaterial):
Expand Down
Loading