diff --git a/optiland/materials/base.py b/optiland/materials/base.py index 2194be1a8..bc1db4c1b 100644 --- a/optiland/materials/base.py +++ b/optiland/materials/base.py @@ -78,6 +78,23 @@ def _create_cache_key(self, wavelength: float | be.ndarray, **kwargs) -> tuple: wavelength_key = wavelength return (wavelength_key,) + tuple(sorted(kwargs.items())) + @staticmethod + def _requires_grad(value) -> bool: + """Check if a value is a torch tensor that requires gradient.""" + return hasattr(value, "requires_grad") and value.requires_grad + + @staticmethod + def _detach_if_tensor(value): + """Detach a torch tensor to sever the computation graph link. + + This prevents the 'backward through the graph a second time' error + that occurs when a cached tensor still references a freed computation + graph. + """ + if hasattr(value, "detach"): + return value.detach() + return value + def n(self, wavelength: float | be.ndarray, **kwargs) -> float | be.ndarray: """Calculates the refractive index at a given wavelength with caching. @@ -95,8 +112,17 @@ def n(self, wavelength: float | be.ndarray, **kwargs) -> float | be.ndarray: return self._n_cache[cache_key] result = self._calculate_n(wavelength, **kwargs) - self._n_cache[cache_key] = result - return result + + # If the result requires grad, it is connected to an optimization + # variable (e.g. the index itself is being optimized). In that case + # we must NOT cache — every forward pass needs a fresh graph. + if self._requires_grad(result): + return result + + # Otherwise the value is a constant w.r.t. optimization variables. + # Detach before caching to avoid holding a stale computation graph. + self._n_cache[cache_key] = self._detach_if_tensor(result) + return self._n_cache[cache_key] def k(self, wavelength: float | be.ndarray, **kwargs) -> float | be.ndarray: """Calculates the extinction coefficient at a given wavelength with caching. @@ -115,8 +141,12 @@ def k(self, wavelength: float | be.ndarray, **kwargs) -> float | be.ndarray: return self._k_cache[cache_key] result = self._calculate_k(wavelength, **kwargs) - self._k_cache[cache_key] = result - return result + # Same logic as n(): skip cache if result is differentiable. + if self._requires_grad(result): + return result + + self._k_cache[cache_key] = self._detach_if_tensor(result) + return self._k_cache[cache_key] @abstractmethod def _calculate_n( diff --git a/optiland/optimization/optimizer/torch/base.py b/optiland/optimization/optimizer/torch/base.py index dc905ceb7..afb8e5f70 100644 --- a/optiland/optimization/optimizer/torch/base.py +++ b/optiland/optimization/optimizer/torch/base.py @@ -53,7 +53,8 @@ def __init__(self, problem: OptimizationProblem): be.grad_mode.enable() # Initialize parameters as torch.nn.Parameter objects - initial_params = [var.variable.get_value() for var in self.problem.variables] + # Use var.value (scaled) to match the scaled bounds from var.bounds + initial_params = [var.value for var in self.problem.variables] self.params = [torch.nn.Parameter(be.array(p)) for p in initial_params] @abstractmethod @@ -118,8 +119,9 @@ def optimize( optimizer.zero_grad() # 1. Update the model state from the current nn.Parameter values. + # Use var.update() which inverse-scales from scaled space for k, param in enumerate(self.params): - self.problem.variables[k].variable.update_value(param) + self.problem.variables[k].update(param) # 2. Update any dependent properties. self.problem.update_optics() @@ -147,7 +149,7 @@ def optimize( # Final update to ensure the model reflects the last optimized state for k, param in enumerate(self.params): - self.problem.variables[k].variable.update_value(param) + self.problem.variables[k].update(param) self.problem.update_optics() final_loss = self.problem.sum_squared().item() diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 5893cead3..fa0a8cd9e 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -1754,6 +1754,16 @@ def test_view_with_all_zero_intensities(self, tf_spot): def read_zmx_file(file_path, skip_lines, cols=(0, 1)): + import os + if not os.path.exists(file_path): + this_dir = os.path.dirname(os.path.abspath(__file__)) + if file_path.startswith("tests/") or file_path.startswith("tests\\"): + candidate = os.path.join(this_dir, file_path[6:]) + else: + candidate = os.path.join(this_dir, file_path) + if os.path.exists(candidate): + file_path = candidate + try: data = np.loadtxt( file_path, skiprows=skip_lines, usecols=cols, encoding="utf-16" diff --git a/tests/test_materials.py b/tests/test_materials.py index 013a854a6..840a2aa98 100644 --- a/tests/test_materials.py +++ b/tests/test_materials.py @@ -55,6 +55,170 @@ def _calculate_k(self, wavelength, **kwargs): # and a cache hit material.n(wavelength_torch_2, temperature=25) assert material._calculate_n.call_count == 3 + def test_detach_if_tensor_plain_value(self, set_test_backend): + """_detach_if_tensor returns plain values unchanged.""" + assert BaseMaterial._detach_if_tensor(1.5) == 1.5 + assert BaseMaterial._detach_if_tensor(None) is None + + def test_detach_if_tensor_numpy(self, set_test_backend): + """_detach_if_tensor returns numpy arrays unchanged.""" + arr = np.array([1.5, 1.6]) + result = BaseMaterial._detach_if_tensor(arr) + np.testing.assert_array_equal(result, arr) + + def test_requires_grad_plain_value(self, set_test_backend): + """_requires_grad returns False for plain values.""" + assert BaseMaterial._requires_grad(1.5) is False + assert BaseMaterial._requires_grad(np.array(1.5)) is False + + +@pytest.mark.skipif( + "torch" not in be.list_available_backends(), + reason="PyTorch not installed", +) +class TestBaseMaterialTorchCaching: + """Tests for material caching behavior specific to torch backend. + + These tests verify: + - Cached tensors are detached (no stale computation graph) + - Repeated forward+backward passes don't raise RuntimeError + - Cache is bypassed when the result requires grad (optimizable index) + """ + + @pytest.fixture(autouse=True) + def _setup_torch(self): + be.set_backend("torch") + be.set_device("cpu") + be.grad_mode.enable() + be.set_precision("float64") + yield + be.set_backend("numpy") + + def test_detach_if_tensor_torch(self): + """_detach_if_tensor detaches torch tensors.""" + import torch + + t = torch.tensor(1.5, requires_grad=True) * 2.0 # has grad_fn + result = BaseMaterial._detach_if_tensor(t) + assert not result.requires_grad + assert result.grad_fn is None + + def test_requires_grad_torch(self): + """_requires_grad correctly detects torch tensors with grad.""" + import torch + + leaf = torch.tensor(1.5, requires_grad=True) + assert BaseMaterial._requires_grad(leaf) is True + + no_grad = torch.tensor(1.5) + assert BaseMaterial._requires_grad(no_grad) is False + + def test_cached_value_is_detached_when_no_grad(self): + """Cached n/k values must be detached when they don't require grad. + + When grad_mode is off, Sellmeier formula results are torch tensors + without requires_grad. These should be cached and detached to prevent + stale computation graph references if grad_mode is later re-enabled. + """ + import torch + + be.grad_mode.disable() + try: + mat = materials.Material("N-BK7") + result = mat.n(0.5876) + + # Should be cached since result doesn't require grad + assert len(mat._n_cache) > 0, "Result should be cached" + + cached = list(mat._n_cache.values())[0] + if isinstance(cached, torch.Tensor): + assert cached.grad_fn is None, ( + "Cached tensor should be detached (no grad_fn)" + ) + finally: + be.grad_mode.enable() + + def test_cache_bypassed_for_real_material_under_grad_mode(self): + """Under torch + grad_mode, real material results require_grad=True + because be.array() creates grad-enabled tensors. The cache must be + bypassed in this case. + """ + mat = materials.Material("N-BK7") + mat.n(0.5876) + + assert len(mat._n_cache) == 0, ( + "Cache should be empty — Sellmeier result requires_grad under " + "grad_mode, so caching must be skipped" + ) + + def test_no_runtime_error_on_repeated_backward(self): + """Repeated forward+backward passes must not raise RuntimeError. + + This is the regression test for the 'backward through the graph a + second time' error. We simulate a multi-step optimization by calling + material.n() repeatedly and computing a loss that flows through it. + """ + import torch + + mat = materials.Material("N-BK7") + + # Simulate multiple optimizer steps + param = torch.nn.Parameter(torch.tensor(0.55, dtype=torch.float64)) + + for _ in range(3): + # Forward: compute n at a wavelength that depends on param + n_val = mat.n(param) + loss = (n_val - 1.5) ** 2 + loss.backward() + + with torch.no_grad(): + param.data -= 0.01 * param.grad + param.grad.zero_() + + def test_cache_bypassed_when_result_requires_grad(self): + """n() must NOT cache when the result requires_grad. + + If the refractive index itself is an optimization variable (e.g. + IdealMaterial with a nn.Parameter index), caching would break + gradient flow — grad(loss, n) would always be zero. + """ + import torch + + # Create an IdealMaterial whose index is a nn.Parameter + mat = materials.IdealMaterial(n=1.5) + mat.index = torch.nn.Parameter(torch.tensor([1.5], dtype=torch.float64)) + + # First call — result requires grad, should NOT be cached + result1 = mat.n(0.55) + assert result1.requires_grad, "Result should require grad" + assert len(mat._n_cache) == 0, "Should not cache requires_grad result" + + # Second call — should recompute (not return stale cached value) + result2 = mat.n(0.55) + assert result2.requires_grad, "Second result should also require grad" + assert len(mat._n_cache) == 0, "Cache should still be empty" + + # Verify gradient actually flows through + loss = result2.sum() + loss.backward() + assert mat.index.grad is not None, "Gradient should flow to index" + assert mat.index.grad.abs().sum() > 0, "Gradient should be non-zero" + + def test_cache_bypassed_for_k_when_requires_grad(self): + """k() has the same bypass logic as n().""" + import torch + + mat = materials.IdealMaterial(n=1.5, k=0.1) + mat.absorp = torch.nn.Parameter(torch.tensor([0.1], dtype=torch.float64)) + + result = mat.k(0.55) + assert result.requires_grad + assert len(mat._k_cache) == 0 + + loss = result.sum() + loss.backward() + assert mat.absorp.grad is not None + def build_model(material: BaseMaterial): diff --git a/tests/test_torch_optimization.py b/tests/test_torch_optimization.py index e66012674..900ff1178 100644 --- a/tests/test_torch_optimization.py +++ b/tests/test_torch_optimization.py @@ -212,3 +212,150 @@ def test_display_output_controlled_by_disp_flag(self, capsys, optimizer_class, d assert "Loss" in captured.out else: assert "Loss" not in captured.out + + +class TestTorchOptimizerScaledSpace: + """ + Tests that verify the Torch optimizers work in scaled parameter space, + consistent with the bounds from var.bounds. + + This prevents the bug where raw parameters (e.g. radius=15.0) were clamped + by scaled bounds (e.g. [-0.9, -0.7]), corrupting the value and producing + NaN loss. + """ + + def test_params_initialized_in_scaled_space(self): + """ + Optimizer params must match var.value (scaled), not + var.variable.get_value() (raw). + + For a RadiusVariable with LinearScaler(factor=1/100, offset=-1.0), + raw=5.0 should become scaled = 5.0 * 0.01 - 1.0 = -0.95. + """ + problem, _ = setup_problem(initial_value=5.0, min_val=1.0, max_val=10.0) + optimizer = TorchAdamOptimizer(problem) + + # The optimizer param should be the scaled value, not the raw value + scaled_param = optimizer.params[0].item() + raw_value = problem.variables[0].variable.get_value() + expected_scaled = problem.variables[0].value + + assert abs(scaled_param - expected_scaled) < 1e-10, ( + f"Param {scaled_param} should equal scaled value {expected_scaled}, " + f"not raw value {raw_value}" + ) + # Sanity: scaled and raw should differ for this scaler + assert abs(scaled_param - raw_value) > 0.1, ( + "Scaled and raw values should be different" + ) + + def test_bounds_consistent_with_params(self): + """ + After _apply_bounds(), the parameter must remain in the valid scaled + range — not be corrupted by a space mismatch. + """ + problem, _ = setup_problem(initial_value=5.0, min_val=1.0, max_val=10.0) + optimizer = TorchAdamOptimizer(problem) + + min_bound, max_bound = problem.variables[0].bounds + + # Apply bounds and verify param stays in range (with float tolerance) + optimizer._apply_bounds() + param_val = optimizer.params[0].item() + + tol = 1e-6 + assert min_bound - tol <= param_val <= max_bound + tol, ( + f"Param {param_val} outside scaled bounds [{min_bound}, {max_bound}]" + ) + + def test_bounded_optimization_no_nan(self): + """ + Regression test: optimization with bounded variables must not produce + NaN loss. This was the exact failure mode when raw params were clamped + by scaled bounds. + """ + import math + + problem, lens = setup_problem( + initial_value=5.0, + min_val=1.0, + max_val=10.0, + target=12.0, + ) + optimizer = TorchAdamOptimizer(problem) + result = optimizer.optimize(n_steps=20, disp=False) + + assert not math.isnan(result.fun), ( + f"Loss should not be NaN, got {result.fun}" + ) + assert result.fun >= 0.0, ( + f"Loss should be non-negative, got {result.fun}" + ) + + def test_bounded_variable_stays_in_physical_range(self): + """ + After optimization with bounds [1, 10], the actual radius on the + optic must be within that physical range — not corrupted to a scaled + value like -0.7. + """ + problem, lens = setup_problem( + initial_value=5.0, + min_val=1.0, + max_val=10.0, + target=12.0, + ) + optimizer = TorchAdamOptimizer(problem) + optimizer.optimize(n_steps=10, disp=False) + + # Get the actual physical radius from the optic + raw_radius = problem.variables[0].variable.get_value() + if hasattr(raw_radius, "item"): + raw_radius = raw_radius.item() + + # Allow small floating-point overshoot from inverse-scaling + tol = 1e-4 + assert 1.0 - tol <= raw_radius <= 10.0 + tol, ( + f"Physical radius {raw_radius} outside bounds [1.0, 10.0]. " + "This suggests a scaled/unscaled space mismatch." + ) + + def test_optimizer_with_real_material_no_nan(self): + """ + End-to-end regression test using Material("N-BK7") with bounded + variables — the exact combination that triggered both the NaN bug + (scaled/unscaled mismatch) and the RuntimeError (stale graph in + cached material tensors). + """ + import math + + from optiland.optic import Optic + + lens = Optic() + lens.add_surface(index=0, thickness=be.inf) + lens.add_surface( + index=1, thickness=7, radius=15, material="N-BK7", is_stop=True, + ) + lens.add_surface(index=2, thickness=30, radius=-1000) + lens.add_surface(index=3) + lens.set_aperture(aperture_type="EPD", value=15) + lens.set_field_type(field_type="angle") + lens.add_field(y=0) + lens.add_wavelength(value=0.55, is_primary=True) + + problem = OptimizationProblem() + problem.add_operand( + operand_type="f2", target=50, weight=1, + input_data={"optic": lens}, + ) + problem.add_variable( + lens, "radius", surface_number=1, min_val=10, max_val=30, + ) + problem.add_variable(lens, "thickness", surface_number=2) + + optimizer = TorchAdamOptimizer(problem) + result = optimizer.optimize(n_steps=10, disp=False) + + assert not math.isnan(result.fun), ( + f"Loss should not be NaN with Material('N-BK7') and bounded " + f"variables, got {result.fun}" + ) \ No newline at end of file