@@ -55,6 +55,170 @@ def _calculate_k(self, wavelength, **kwargs):
5555 # and a cache hit
5656 material .n (wavelength_torch_2 , temperature = 25 )
5757 assert material ._calculate_n .call_count == 3
def test_detach_if_tensor_plain_value(self, set_test_backend):
    """Plain Python values pass through _detach_if_tensor unchanged."""
    detached_float = BaseMaterial._detach_if_tensor(1.5)
    assert detached_float == 1.5
    # None is also a legal cache value and must survive untouched.
    assert BaseMaterial._detach_if_tensor(None) is None
62+
def test_detach_if_tensor_numpy(self, set_test_backend):
    """Numpy arrays pass through _detach_if_tensor unchanged."""
    values = np.array([1.5, 1.6])
    detached = BaseMaterial._detach_if_tensor(values)
    np.testing.assert_array_equal(detached, values)
68+
def test_requires_grad_plain_value(self, set_test_backend):
    """_requires_grad reports False for anything that is not a grad tensor."""
    for plain in (1.5, np.array(1.5)):
        assert BaseMaterial._requires_grad(plain) is False
73+
74+
@pytest.mark.skipif(
    "torch" not in be.list_available_backends(),
    reason="PyTorch not installed",
)
class TestBaseMaterialTorchCaching:
    """Torch-backend-specific caching behavior of materials.

    Guarantees exercised here:
    - cached n/k tensors are detached, so no stale computation graph survives
    - repeated forward+backward passes never raise RuntimeError
    - caching is skipped whenever the result requires grad (optimizable index)
    """

    @pytest.fixture(autouse=True)
    def _setup_torch(self):
        # Run every test on a grad-enabled float64 torch/CPU backend, then
        # fall back to numpy. NOTE(review): precision and grad-mode are not
        # restored to their prior values here — confirm the global defaults
        # make that unnecessary.
        be.set_backend("torch")
        be.set_device("cpu")
        be.grad_mode.enable()
        be.set_precision("float64")
        yield
        be.set_backend("numpy")

    def test_detach_if_tensor_torch(self):
        """_detach_if_tensor strips the autograd graph from torch tensors."""
        import torch

        # Multiplying a leaf gives a tensor with a grad_fn to strip.
        with_graph = torch.tensor(1.5, requires_grad=True) * 2.0
        detached = BaseMaterial._detach_if_tensor(with_graph)
        assert not detached.requires_grad
        assert detached.grad_fn is None

    def test_requires_grad_torch(self):
        """_requires_grad distinguishes grad-enabled torch tensors."""
        import torch

        grad_leaf = torch.tensor(1.5, requires_grad=True)
        plain_tensor = torch.tensor(1.5)
        assert BaseMaterial._requires_grad(grad_leaf) is True
        assert BaseMaterial._requires_grad(plain_tensor) is False

    def test_cached_value_is_detached_when_no_grad(self):
        """Cached n/k values must be detached when they don't require grad.

        With grad_mode off, Sellmeier results are grad-free torch tensors.
        They should be cached detached so no stale graph reference survives
        a later re-enable of grad_mode.
        """
        import torch

        be.grad_mode.disable()
        try:
            glass = materials.Material("N-BK7")
            glass.n(0.5876)

            # A grad-free result is eligible for caching.
            assert len(glass._n_cache) > 0, "Result should be cached"

            entry = next(iter(glass._n_cache.values()))
            if isinstance(entry, torch.Tensor):
                assert entry.grad_fn is None, (
                    "Cached tensor should be detached (no grad_fn)"
                )
        finally:
            be.grad_mode.enable()

    def test_cache_bypassed_for_real_material_under_grad_mode(self):
        """With torch + grad_mode, real-material results require grad
        (be.array() builds grad-enabled tensors), so the cache must be
        bypassed entirely.
        """
        glass = materials.Material("N-BK7")
        glass.n(0.5876)

        assert len(glass._n_cache) == 0, (
            "Cache should be empty — Sellmeier result requires_grad under "
            "grad_mode, so caching must be skipped"
        )

    def test_no_runtime_error_on_repeated_backward(self):
        """Repeated forward+backward passes must not raise RuntimeError.

        Regression test for the 'backward through the graph a second time'
        failure: a multi-step optimization loop calls material.n() each step
        and backpropagates a loss through it.
        """
        import torch

        glass = materials.Material("N-BK7")
        wavelength = torch.nn.Parameter(torch.tensor(0.55, dtype=torch.float64))

        # Three simulated optimizer steps; each needs a fresh graph.
        for _step in range(3):
            index = glass.n(wavelength)
            loss = (index - 1.5) ** 2
            loss.backward()

            with torch.no_grad():
                wavelength.data -= 0.01 * wavelength.grad
                wavelength.grad.zero_()

    def test_cache_bypassed_when_result_requires_grad(self):
        """n() must NOT cache when its result requires grad.

        If the index itself is an optimization variable (IdealMaterial with
        an nn.Parameter index), a cached value would freeze the graph and
        grad(loss, n) would always come out zero.
        """
        import torch

        ideal = materials.IdealMaterial(n=1.5)
        ideal.index = torch.nn.Parameter(torch.tensor([1.5], dtype=torch.float64))

        first = ideal.n(0.55)
        assert first.requires_grad, "Result should require grad"
        assert len(ideal._n_cache) == 0, "Should not cache requires_grad result"

        # A second call must recompute rather than serve a stale value.
        second = ideal.n(0.55)
        assert second.requires_grad, "Second result should also require grad"
        assert len(ideal._n_cache) == 0, "Cache should still be empty"

        # Gradient must actually reach the parameter.
        second.sum().backward()
        assert ideal.index.grad is not None, "Gradient should flow to index"
        assert ideal.index.grad.abs().sum() > 0, "Gradient should be non-zero"

    def test_cache_bypassed_for_k_when_requires_grad(self):
        """k() applies the same requires_grad bypass as n()."""
        import torch

        ideal = materials.IdealMaterial(n=1.5, k=0.1)
        ideal.absorp = torch.nn.Parameter(torch.tensor([0.1], dtype=torch.float64))

        extinction = ideal.k(0.55)
        assert extinction.requires_grad
        assert len(ideal._k_cache) == 0

        extinction.sum().backward()
        assert ideal.absorp.grad is not None
221+
58222
59223
60224def build_model (material : BaseMaterial ):
0 commit comments