
Commit 4fa414c

Do not coerce gradients to TensorVariable
This could cause spurious disconnected errors, because the tensorified variable was not in the graph of the cost.
1 parent 3082ed5 commit 4fa414c
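For context, here is a hedged sketch of the failure mode this commit fixes, built from the scenario the new test exercises; the pre-fix behavior is inferred from the commit message and diff, not reproduced verbatim:

# Sketch of the bug (assumption: pre-fix behavior inferred from the
# commit message and the new test, not an exact traceback).
from pytensor.gradient import Lop
from pytensor.scalar import float64

xtm1 = float64("xtm1")          # a *scalar* Variable, not a TensorVariable
xt = xtm1**2                    # cost graph built from scalar ops

dout_dxt = float64("dout_dxt")  # upstream gradient, also a scalar Variable

# Before this commit, Lop ran pytensor.tensor.as_tensor_variable over f,
# wrt, and eval_points. For scalar Variables that creates *new* wrapper
# nodes that are not ancestors of the original cost graph, so grad()
# reported the inputs as disconnected. After the commit, anything that is
# already a Variable passes through untouched and the graph stays connected:
dout_dxtm1 = Lop(xt, wrt=xtm1, eval_points=dout_dxt)
print(dout_dxtm1.eval({xtm1: 3.0, dout_dxt: 1.5}))  # 9.0 = 2 * 3.0 * 1.5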

File tree

2 files changed: +24 -11 lines changed


pytensor/gradient.py

Lines changed: 12 additions & 11 deletions
@@ -494,22 +494,25 @@ def Lop(
         coordinates of the tensor elements.
     If `f` is a list/tuple, then return a list/tuple with the results.
     """
+    from pytensor.tensor import as_tensor_variable
+
     if not isinstance(eval_points, list | tuple):
-        _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
-    else:
-        _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
+        eval_points = [eval_points]
+    _eval_points = [
+        x if isinstance(x, Variable) else as_tensor_variable(x) for x in eval_points
+    ]
 
     if not isinstance(f, list | tuple):
-        _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
-    else:
-        _f = [pytensor.tensor.as_tensor_variable(x) for x in f]
+        f = [f]
+    _f = [x if isinstance(x, Variable) else as_tensor_variable(x) for x in f]
 
     grads = list(_eval_points)
 
+    using_list = isinstance(wrt, list)
+    using_tuple = isinstance(wrt, tuple)
     if not isinstance(wrt, list | tuple):
-        _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
-    else:
-        _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
+        wrt = [wrt]
+    _wrt = [x if isinstance(x, Variable) else as_tensor_variable(x) for x in wrt]
 
     assert len(_f) == len(grads)
     known = dict(zip(_f, grads, strict=True))
@@ -523,8 +526,6 @@ def Lop(
         return_disconnected=return_disconnected,
     )
 
-    using_list = isinstance(wrt, list)
-    using_tuple = isinstance(wrt, tuple)
     return as_list_or_tuple(using_list, using_tuple, ret)
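The replacement pattern above is the whole fix: inputs that are already symbolic Variables (tensor or scalar) are kept as-is, and only raw Python or numpy values get tensorified. Note also that `using_list`/`using_tuple` had to move up, because `wrt` is now rebound to a list before the end of the function. A minimal standalone sketch of the coercion rule (the helper name `coerce_like_lop` is mine, not from the commit):

# Minimal sketch of the new coercion rule (hypothetical helper name).
from pytensor.graph.basic import Variable
from pytensor.tensor import as_tensor_variable

def coerce_like_lop(inputs):
    if not isinstance(inputs, list | tuple):
        inputs = [inputs]
    # Existing Variables pass through unchanged, so scalar Variables stay
    # connected to their original graph; only raw data (Python numbers,
    # numpy arrays) is wrapped as a TensorVariable.
    return [x if isinstance(x, Variable) else as_tensor_variable(x) for x in inputs]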
tests/test_gradient.py

Lines changed: 12 additions & 0 deletions
@@ -11,6 +11,7 @@
     DisconnectedType,
     GradClip,
     GradScale,
+    Lop,
     NullTypeGradError,
     Rop,
     UndefinedGrad,
@@ -32,6 +33,7 @@
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import Op
 from pytensor.graph.traversal import graph_inputs
+from pytensor.scalar import float64
 from pytensor.scan.op import Scan
 from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, sqrt, tanh
 from pytensor.tensor.math import sum as pt_sum
@@ -1207,3 +1209,13 @@ def test_multiple_wrt(self):
         hessp_x_eval, hessp_y_eval = hessp_fn(**test)
         np.testing.assert_allclose(hessp_x_eval, [2, 4, 6])
         np.testing.assert_allclose(hessp_y_eval, [-6, -4, -2])
+
+
+def test_scalar_Lop():
+    xtm1 = float64("xtm1")
+    xt = xtm1**2
+
+    dout_dxt = float64("dout_dxt")
+    dout_dxtm1 = Lop(xt, wrt=xtm1, eval_points=dout_dxt)
+    assert dout_dxtm1.type == dout_dxt.type
+    assert dout_dxtm1.eval({xtm1: 3.0, dout_dxt: 1.5}) == 2 * 3.0 * 1.5
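The expected value in the new test is just the chain rule: d(xtm1**2)/dxtm1 = 2*xtm1 = 6.0 at xtm1 = 3.0, scaled by the upstream gradient 1.5 gives 9.0. As a hedged cross-check of my own (not part of the commit), the same result should fall out of grad with known_grads, which is what Lop wraps internally per the diff above:

# Cross-check (my addition, not part of the commit): Lop should agree
# with grad(None, ..., known_grads=...), which it calls internally.
from pytensor.gradient import Lop, grad
from pytensor.scalar import float64

xtm1 = float64("xtm1")
xt = xtm1**2
dout_dxt = float64("dout_dxt")

via_lop = Lop(xt, wrt=xtm1, eval_points=dout_dxt)
via_grad = grad(None, wrt=xtm1, known_grads={xt: dout_dxt})

point = {xtm1: 3.0, dout_dxt: 1.5}
assert via_lop.eval(point) == via_grad.eval(point) == 9.0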
