 
 import pytensor
 import pytensor.tensor as pt
-from pytensor import function
-from pytensor.gradient import Lop, Rop, grad, grad_undefined
+from pytensor import config, function
+from pytensor.gradient import (
+    Lop,
+    NullTypeGradError,
+    Rop,
+    grad,
+    grad_undefined,
+)
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor.math import argmax, dot
@@ -61,6 +67,10 @@ class RopLopChecker:
     Rop to class that inherit from it.
     """
 
+    @staticmethod
+    def rtol():
+        return 1e-7 if config.floatX == "float64" else 1e-5
+
     def setup_method(self):
         # Using vectors make things a lot simpler for generating the same
         # computations using scan
@@ -72,13 +82,13 @@ def setup_method(self):
         self.mv = matrix("mv")
         self.mat_in_shape = (5 + self.rng.integers(3), 5 + self.rng.integers(3))
 
-    def check_nondiff_rop(self, y):
+    def check_nondiff_rop(self, y, x, v):
         """
         If your op is not differentiable(so you can't define Rop)
         test that an error is raised.
         """
         with pytest.raises(ValueError):
-            Rop(y, self.x, self.v)
+            Rop(y, x, v)
 
     def check_mat_rop_lop(self, y, out_shape):
         """
@@ -115,13 +125,13 @@ def check_mat_rop_lop(self, y, out_shape):
         )
         scan_f = function([self.mx, self.mv], sy, on_unused_input="ignore")
 
-        v1 = rop_f(vx, vv)
-        v2 = scan_f(vx, vv)
-
-        assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}"
+        v_ref = scan_f(vx, vv)
+        np.testing.assert_allclose(rop_f(vx, vv), v_ref)
 
         self.check_nondiff_rop(
-            pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)})
+            pytensor.clone_replace(y, replace={self.mx: break_op(self.mx)}),
+            self.mx,
+            self.mv,
         )
 
         vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX)
@@ -131,15 +141,17 @@ def check_mat_rop_lop(self, y, out_shape):
         sy = grad((self.v * y).sum(), self.mx)
         scan_f = function([self.mx, self.v], sy)
 
-        v1 = lop_f(vx, vv)
-        v2 = scan_f(vx, vv)
-        assert np.allclose(v1, v2), f"LOP mismatch: {v1} {v2}"
+        v = lop_f(vx, vv)
+        v_ref = scan_f(vx, vv)
+        np.testing.assert_allclose(v, v_ref)
 
-    def check_rop_lop(self, y, out_shape):
+    def check_rop_lop(self, y, out_shape, check_nondiff_rop: bool = True):
         """
         As check_mat_rop_lop, except the input is self.x which is a
         vector. The output is still a vector.
         """
+        rtol = self.rtol()
+
         # TEST ROP
         vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
         vv = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
@@ -152,24 +164,17 @@ def check_rop_lop(self, y, out_shape):
             non_sequences=[y, self.x],
         )
         sy = dot(J, self.v)
-
         scan_f = function([self.x, self.v], sy, on_unused_input="ignore")
 
-        v1 = rop_f(vx, vv)
-        v2 = scan_f(vx, vv)
-        assert np.allclose(v1, v2), f"ROP mismatch: {v1} {v2}"
+        v_ref = scan_f(vx, vv)
+        np.testing.assert_allclose(rop_f(vx, vv), v_ref, rtol=rtol)
 
-        try:
-            Rop(
+        if check_nondiff_rop:
+            self.check_nondiff_rop(
                 pytensor.clone_replace(y, replace={self.x: break_op(self.x)}),
                 self.x,
                 self.v,
             )
-        except ValueError:
-            pytest.skip(
-                "Rop does not handle non-differentiable inputs "
-                "correctly. Bug exposed by fixing Add.grad method."
-            )
 
         vx = np.asarray(self.rng.uniform(size=self.in_shape), pytensor.config.floatX)
         vv = np.asarray(self.rng.uniform(size=out_shape), pytensor.config.floatX)
@@ -182,22 +187,20 @@ def check_rop_lop(self, y, out_shape):
             non_sequences=[y, self.x],
         )
         sy = dot(self.v, J)
-
         scan_f = function([self.x, self.v], sy)
 
-        v1 = lop_f(vx, vv)
-        v2 = scan_f(vx, vv)
-        assert np.allclose(v1, v2), f"LOP mismatch: {v1} {v2}"
+        v = lop_f(vx, vv)
+        v_ref = scan_f(vx, vv)
+        np.testing.assert_allclose(v, v_ref, rtol=rtol)
 
 
 class TestRopLop(RopLopChecker):
     def test_max(self):
-        # self.check_mat_rop_lop(pt_max(self.mx, axis=[0,1])[0], ())
         self.check_mat_rop_lop(pt_max(self.mx, axis=0), (self.mat_in_shape[1],))
         self.check_mat_rop_lop(pt_max(self.mx, axis=1), (self.mat_in_shape[0],))
 
     def test_argmax(self):
-        self.check_nondiff_rop(argmax(self.mx, axis=1))
+        self.check_nondiff_rop(argmax(self.mx, axis=1), self.mx, self.mv)
 
     def test_subtensor(self):
         self.check_rop_lop(self.x[:4], (4,))
@@ -252,10 +255,14 @@ def test_dot(self):
         insh = self.in_shape[0]
         vW = np.asarray(self.rng.uniform(size=(insh, insh)), pytensor.config.floatX)
         W = pytensor.shared(vW)
-        self.check_rop_lop(dot(self.x, W), self.in_shape)
+        # check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
+        # See: test_Rop_partially_differentiable_paths
+        self.check_rop_lop(dot(self.x, W), self.in_shape, check_nondiff_rop=False)
 
     def test_elemwise0(self):
-        self.check_rop_lop((self.x + 1) ** 2, self.in_shape)
+        # check_nondiff_rop reveals an error in how Rop handles non-differentiable paths
+        # See: test_Rop_partially_differentiable_paths
+        self.check_rop_lop((self.x + 1) ** 2, self.in_shape, check_nondiff_rop=False)
 
     def test_elemwise1(self):
         self.check_rop_lop(self.x + pt.cast(self.x, "int32"), self.in_shape)
@@ -288,15 +295,8 @@ def test_alloc(self):
         )
 
     def test_invalid_input(self):
-        success = False
-
-        try:
+        with pytest.raises(ValueError):
             Rop(0.0, [matrix()], [vector()])
-            success = True
-        except ValueError:
-            pass
-
-        assert not success
 
     def test_multiple_outputs(self):
         m = matrix("m")
@@ -322,12 +322,54 @@ def test_multiple_outputs(self):
         f = pytensor.function([m, v, m_, v_], all_outs)
         f(mval, vval, m_val, v_val)
 
-    def test_Rop_dot_bug_18Oct2013_Jeremiah(self):
+    @pytest.mark.xfail()
+    def test_Rop_partially_differentiable_paths(self):
         # This test refers to a bug reported by Jeremiah Lowin on 18th Oct
         # 2013. The bug consists when through a dot operation there is only
         # one differentiable path (i.e. there is no gradient wrt to one of
         # the inputs).
         x = pt.arange(20.0).reshape([1, 20])
-        v = pytensor.shared(np.ones([20]))
+        v = pytensor.shared(np.ones([20]), name="v")
         d = dot(x, v).sum()
-        Rop(grad(d, v), v, v)
+
+        Rop(
+            grad(d, v),
+            v,
+            v,
+            disconnected_outputs="raise",
+        )
+
+        # 2025: Here is an unambiguous test for the original commented issue:
+        x = pt.matrix("x")
+        y = pt.matrix("y")
+        out = dot(x, break_op(y)).sum()
+        # Should not raise an error
+        Rop(
+            out,
+            [x],
+            [x.type()],
+            disconnected_outputs="raise",
+        )
+
+        # More extensive testing shows that the Rop implementation FAILS to raise when
+        # the cost is linked through strictly non-differentiable paths.
+        # This is not Dot specific, we would observe the same with any operation where the gradient
+        # with respect to one of the inputs does not depend on the original input (such as `mul`, `add`, ...)
+        out = dot(break_op(x), y).sum()
+        with pytest.raises((ValueError, NullTypeGradError)):
+            Rop(
+                out,
+                [x],
+                [x.type()],
+                disconnected_outputs="raise",
+            )
+
+        # Only when both paths are non-differentiable is an error correctly raised again.
+        out = dot(break_op(x), break_op(y)).sum()
+        with pytest.raises((ValueError, NullTypeGradError)):
+            Rop(
+                out,
+                [x],
+                [x.type()],
+                disconnected_outputs="raise",
+            )
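
For readers skimming the diff, here is a minimal, self-contained sketch (not part of the commit) of the consistency check that the check_rop_lop-style helpers perform: the forward-mode Rop(y, x, v) is compared against the Jacobian-vector product J @ v, and the reverse-mode Lop(y, x, v) against the vector-Jacobian product v @ J, which equals grad((v * y).sum(), x). The example expression, shapes, seed, and tolerance are illustrative choices, and pytensor.gradient.jacobian is used here in place of the scan-built Jacobian the tests construct.

# Minimal sketch of the Rop/Lop consistency check (illustrative, not from the commit).
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.gradient import Lop, Rop, jacobian

x = pt.vector("x")
v = pt.vector("v")
y = (x + 1) ** 2  # any differentiable expression of x

J = jacobian(y, x)  # full Jacobian dy/dx, analogous to the scan-built J in the tests

# Forward mode: Rop(y, x, v) should match the Jacobian-vector product J @ v.
f_rop = pytensor.function([x, v], [Rop(y, x, v), pt.dot(J, v)])

# Reverse mode: Lop(y, x, v) should match the vector-Jacobian product v @ J.
f_lop = pytensor.function([x, v], [Lop(y, x, v), pt.dot(v, J)])

rng = np.random.default_rng(0)
vx = rng.uniform(size=5).astype(pytensor.config.floatX)
vv = rng.uniform(size=5).astype(pytensor.config.floatX)

np.testing.assert_allclose(*f_rop(vx, vv), rtol=1e-5)
np.testing.assert_allclose(*f_lop(vx, vv), rtol=1e-5)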