Skip to content

Commit 36a45d9

Browse files
committed
Solves #483 and #484
1 parent fd572a2 commit 36a45d9

File tree

5 files changed

+360
-7
lines changed

5 files changed

+360
-7
lines changed

optiland/materials/base.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,23 @@ def _create_cache_key(self, wavelength: float | be.ndarray, **kwargs) -> tuple:
7878
wavelength_key = wavelength
7979
return (wavelength_key,) + tuple(sorted(kwargs.items()))
8080

81+
@staticmethod
82+
def _requires_grad(value) -> bool:
83+
"""Check if a value is a torch tensor that requires gradient."""
84+
return hasattr(value, "requires_grad") and value.requires_grad
85+
86+
@staticmethod
87+
def _detach_if_tensor(value):
88+
"""Detach a torch tensor to sever the computation graph link.
89+
90+
This prevents the 'backward through the graph a second time' error
91+
that occurs when a cached tensor still references a freed computation
92+
graph.
93+
"""
94+
if hasattr(value, "detach"):
95+
return value.detach()
96+
return value
97+
8198
def n(self, wavelength: float | be.ndarray, **kwargs) -> float | be.ndarray:
8299
"""Calculates the refractive index at a given wavelength with caching.
83100
@@ -95,8 +112,17 @@ def n(self, wavelength: float | be.ndarray, **kwargs) -> float | be.ndarray:
95112
return self._n_cache[cache_key]
96113

97114
result = self._calculate_n(wavelength, **kwargs)
98-
self._n_cache[cache_key] = result
99-
return result
115+
116+
# If the result requires grad, it is connected to an optimization
117+
# variable (e.g. the index itself is being optimized). In that case
118+
# we must NOT cache — every forward pass needs a fresh graph.
119+
if self._requires_grad(result):
120+
return result
121+
122+
# Otherwise the value is a constant w.r.t. optimization variables.
123+
# Detach before caching to avoid holding a stale computation graph.
124+
self._n_cache[cache_key] = self._detach_if_tensor(result)
125+
return self._n_cache[cache_key]
100126

101127
def k(self, wavelength: float | be.ndarray, **kwargs) -> float | be.ndarray:
102128
"""Calculates the extinction coefficient at a given wavelength with caching.
@@ -115,8 +141,12 @@ def k(self, wavelength: float | be.ndarray, **kwargs) -> float | be.ndarray:
115141
return self._k_cache[cache_key]
116142

117143
result = self._calculate_k(wavelength, **kwargs)
118-
self._k_cache[cache_key] = result
119-
return result
144+
# Same logic as n(): skip cache if result is differentiable.
145+
if self._requires_grad(result):
146+
return result
147+
148+
self._k_cache[cache_key] = self._detach_if_tensor(result)
149+
return self._k_cache[cache_key]
120150

121151
@abstractmethod
122152
def _calculate_n(

optiland/optimization/optimizer/torch/base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ def __init__(self, problem: OptimizationProblem):
5353
be.grad_mode.enable()
5454

5555
# Initialize parameters as torch.nn.Parameter objects
56-
initial_params = [var.variable.get_value() for var in self.problem.variables]
56+
# Use var.value (scaled) to match the scaled bounds from var.bounds
57+
initial_params = [var.value for var in self.problem.variables]
5758
self.params = [torch.nn.Parameter(be.array(p)) for p in initial_params]
5859

5960
@abstractmethod
@@ -118,8 +119,9 @@ def optimize(
118119
optimizer.zero_grad()
119120

120121
# 1. Update the model state from the current nn.Parameter values.
122+
# Use var.update() which inverse-scales from scaled space
121123
for k, param in enumerate(self.params):
122-
self.problem.variables[k].variable.update_value(param)
124+
self.problem.variables[k].update(param)
123125

124126
# 2. Update any dependent properties.
125127
self.problem.update_optics()
@@ -147,7 +149,7 @@ def optimize(
147149

148150
# Final update to ensure the model reflects the last optimized state
149151
for k, param in enumerate(self.params):
150-
self.problem.variables[k].variable.update_value(param)
152+
self.problem.variables[k].update(param)
151153
self.problem.update_optics()
152154

153155
final_loss = self.problem.sum_squared().item()

tests/test_analysis.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1754,6 +1754,16 @@ def test_view_with_all_zero_intensities(self, tf_spot):
17541754

17551755

17561756
def read_zmx_file(file_path, skip_lines, cols=(0, 1)):
1757+
import os
1758+
if not os.path.exists(file_path):
1759+
this_dir = os.path.dirname(os.path.abspath(__file__))
1760+
if file_path.startswith("tests/") or file_path.startswith("tests\\"):
1761+
candidate = os.path.join(this_dir, file_path[6:])
1762+
else:
1763+
candidate = os.path.join(this_dir, file_path)
1764+
if os.path.exists(candidate):
1765+
file_path = candidate
1766+
17571767
try:
17581768
data = np.loadtxt(
17591769
file_path, skiprows=skip_lines, usecols=cols, encoding="utf-16"

tests/test_materials.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,170 @@ def _calculate_k(self, wavelength, **kwargs):
5555
# and a cache hit
5656
material.n(wavelength_torch_2, temperature=25)
5757
assert material._calculate_n.call_count == 3
58+
def test_detach_if_tensor_plain_value(self, set_test_backend):
59+
"""_detach_if_tensor returns plain values unchanged."""
60+
assert BaseMaterial._detach_if_tensor(1.5) == 1.5
61+
assert BaseMaterial._detach_if_tensor(None) is None
62+
63+
def test_detach_if_tensor_numpy(self, set_test_backend):
64+
"""_detach_if_tensor returns numpy arrays unchanged."""
65+
arr = np.array([1.5, 1.6])
66+
result = BaseMaterial._detach_if_tensor(arr)
67+
np.testing.assert_array_equal(result, arr)
68+
69+
def test_requires_grad_plain_value(self, set_test_backend):
70+
"""_requires_grad returns False for plain values."""
71+
assert BaseMaterial._requires_grad(1.5) is False
72+
assert BaseMaterial._requires_grad(np.array(1.5)) is False
73+
74+
75+
@pytest.mark.skipif(
76+
"torch" not in be.list_available_backends(),
77+
reason="PyTorch not installed",
78+
)
79+
class TestBaseMaterialTorchCaching:
80+
"""Tests for material caching behavior specific to torch backend.
81+
82+
These tests verify:
83+
- Cached tensors are detached (no stale computation graph)
84+
- Repeated forward+backward passes don't raise RuntimeError
85+
- Cache is bypassed when the result requires grad (optimizable index)
86+
"""
87+
88+
@pytest.fixture(autouse=True)
89+
def _setup_torch(self):
90+
be.set_backend("torch")
91+
be.set_device("cpu")
92+
be.grad_mode.enable()
93+
be.set_precision("float64")
94+
yield
95+
be.set_backend("numpy")
96+
97+
def test_detach_if_tensor_torch(self):
98+
"""_detach_if_tensor detaches torch tensors."""
99+
import torch
100+
101+
t = torch.tensor(1.5, requires_grad=True) * 2.0 # has grad_fn
102+
result = BaseMaterial._detach_if_tensor(t)
103+
assert not result.requires_grad
104+
assert result.grad_fn is None
105+
106+
def test_requires_grad_torch(self):
107+
"""_requires_grad correctly detects torch tensors with grad."""
108+
import torch
109+
110+
leaf = torch.tensor(1.5, requires_grad=True)
111+
assert BaseMaterial._requires_grad(leaf) is True
112+
113+
no_grad = torch.tensor(1.5)
114+
assert BaseMaterial._requires_grad(no_grad) is False
115+
116+
def test_cached_value_is_detached_when_no_grad(self):
117+
"""Cached n/k values must be detached when they don't require grad.
118+
119+
When grad_mode is off, Sellmeier formula results are torch tensors
120+
without requires_grad. These should be cached and detached to prevent
121+
stale computation graph references if grad_mode is later re-enabled.
122+
"""
123+
import torch
124+
125+
be.grad_mode.disable()
126+
try:
127+
mat = materials.Material("N-BK7")
128+
result = mat.n(0.5876)
129+
130+
# Should be cached since result doesn't require grad
131+
assert len(mat._n_cache) > 0, "Result should be cached"
132+
133+
cached = list(mat._n_cache.values())[0]
134+
if isinstance(cached, torch.Tensor):
135+
assert cached.grad_fn is None, (
136+
"Cached tensor should be detached (no grad_fn)"
137+
)
138+
finally:
139+
be.grad_mode.enable()
140+
141+
def test_cache_bypassed_for_real_material_under_grad_mode(self):
142+
"""Under torch + grad_mode, real material results require_grad=True
143+
because be.array() creates grad-enabled tensors. The cache must be
144+
bypassed in this case.
145+
"""
146+
mat = materials.Material("N-BK7")
147+
mat.n(0.5876)
148+
149+
assert len(mat._n_cache) == 0, (
150+
"Cache should be empty — Sellmeier result requires_grad under "
151+
"grad_mode, so caching must be skipped"
152+
)
153+
154+
def test_no_runtime_error_on_repeated_backward(self):
155+
"""Repeated forward+backward passes must not raise RuntimeError.
156+
157+
This is the regression test for the 'backward through the graph a
158+
second time' error. We simulate a multi-step optimization by calling
159+
material.n() repeatedly and computing a loss that flows through it.
160+
"""
161+
import torch
162+
163+
mat = materials.Material("N-BK7")
164+
165+
# Simulate multiple optimizer steps
166+
param = torch.nn.Parameter(torch.tensor(0.55, dtype=torch.float64))
167+
168+
for _ in range(3):
169+
# Forward: compute n at a wavelength that depends on param
170+
n_val = mat.n(param)
171+
loss = (n_val - 1.5) ** 2
172+
loss.backward()
173+
174+
with torch.no_grad():
175+
param.data -= 0.01 * param.grad
176+
param.grad.zero_()
177+
178+
def test_cache_bypassed_when_result_requires_grad(self):
179+
"""n() must NOT cache when the result requires_grad.
180+
181+
If the refractive index itself is an optimization variable (e.g.
182+
IdealMaterial with a nn.Parameter index), caching would break
183+
gradient flow — grad(loss, n) would always be zero.
184+
"""
185+
import torch
186+
187+
# Create an IdealMaterial whose index is a nn.Parameter
188+
mat = materials.IdealMaterial(n=1.5)
189+
mat.index = torch.nn.Parameter(torch.tensor([1.5], dtype=torch.float64))
190+
191+
# First call — result requires grad, should NOT be cached
192+
result1 = mat.n(0.55)
193+
assert result1.requires_grad, "Result should require grad"
194+
assert len(mat._n_cache) == 0, "Should not cache requires_grad result"
195+
196+
# Second call — should recompute (not return stale cached value)
197+
result2 = mat.n(0.55)
198+
assert result2.requires_grad, "Second result should also require grad"
199+
assert len(mat._n_cache) == 0, "Cache should still be empty"
200+
201+
# Verify gradient actually flows through
202+
loss = result2.sum()
203+
loss.backward()
204+
assert mat.index.grad is not None, "Gradient should flow to index"
205+
assert mat.index.grad.abs().sum() > 0, "Gradient should be non-zero"
206+
207+
def test_cache_bypassed_for_k_when_requires_grad(self):
208+
"""k() has the same bypass logic as n()."""
209+
import torch
210+
211+
mat = materials.IdealMaterial(n=1.5, k=0.1)
212+
mat.absorp = torch.nn.Parameter(torch.tensor([0.1], dtype=torch.float64))
213+
214+
result = mat.k(0.55)
215+
assert result.requires_grad
216+
assert len(mat._k_cache) == 0
217+
218+
loss = result.sum()
219+
loss.backward()
220+
assert mat.absorp.grad is not None
221+
58222

59223

60224
def build_model(material: BaseMaterial):

0 commit comments

Comments
 (0)