Commit b2d8bc2

Do not coerce gradients to TensorVariable (#1685)
* Do not coerce gradients to TensorVariable

  This could cause spurious disconnected errors, because the tensorified
  variable was not in the graph of the cost

* Type-consistent checks

---------

Co-authored-by: jessegrabowski <jessegrabowski@gmail.com>
1 parent 945e979 commit b2d8bc2
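
A minimal sketch of the failure mode described in the commit message, assuming PyTensor's scalar and tensor variable types (variable names are illustrative, not from the repository): as_tensor_variable applied to a Variable that already sits in the cost graph returns a new wrapper variable, and asking for gradients with respect to that wrapper, or keying known gradients on it, looks disconnected from the cost.

# Sketch only: why coercing already-symbolic inputs caused spurious disconnections.
from pytensor.scalar import float64
from pytensor.tensor import as_tensor_variable

x = float64("x")      # a scalar Variable, not a TensorVariable
cost = x**2           # the cost graph contains `x` itself

x_tensor = as_tensor_variable(x)   # a *new* TensorVariable wrapping `x`
assert x_tensor is not x           # ...and it is not part of the cost graph

# Before this commit, Lop coerced `f`, `wrt` and `eval_points` this way, so a
# gradient could be requested for the wrapper instead of the original variable
# and was reported as disconnected.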

2 files changed: +27 -14 lines

pytensor/gradient.py

Lines changed: 15 additions & 14 deletions
@@ -494,22 +494,25 @@ def Lop(
         coordinates of the tensor elements.
     If `f` is a list/tuple, then return a list/tuple with the results.
     """
-    if not isinstance(eval_points, list | tuple):
-        _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
-    else:
-        _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
+    from pytensor.tensor import as_tensor_variable

-    if not isinstance(f, list | tuple):
-        _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
-    else:
-        _f = [pytensor.tensor.as_tensor_variable(x) for x in f]
+    if not isinstance(eval_points, Sequence):
+        eval_points = [eval_points]
+    _eval_points = [
+        x if isinstance(x, Variable) else as_tensor_variable(x) for x in eval_points
+    ]
+
+    if not isinstance(f, Sequence):
+        f = [f]
+    _f = [x if isinstance(x, Variable) else as_tensor_variable(x) for x in f]

     grads = list(_eval_points)

-    if not isinstance(wrt, list | tuple):
-        _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
-    else:
-        _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
+    using_list = isinstance(wrt, list)
+    using_tuple = isinstance(wrt, tuple)
+    if not isinstance(wrt, Sequence):
+        wrt = [wrt]
+    _wrt = [x if isinstance(x, Variable) else as_tensor_variable(x) for x in wrt]

     assert len(_f) == len(grads)
     known = dict(zip(_f, grads, strict=True))

@@ -523,8 +526,6 @@ def Lop(
         return_disconnected=return_disconnected,
     )

-    using_list = isinstance(wrt, list)
-    using_tuple = isinstance(wrt, tuple)
     return as_list_or_tuple(using_list, using_tuple, ret)

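The replacement pattern above keeps anything that is already a Variable, of whatever type, untouched and only converts plain Python or NumPy values through as_tensor_variable. A standalone sketch of that pattern, using a hypothetical helper name that does not exist in pytensor/gradient.py:

from collections.abc import Sequence

from pytensor.graph.basic import Variable
from pytensor.tensor import as_tensor_variable


def _as_symbolic_list(values):
    # Mirrors the coercion used in Lop above: wrap a single input in a list,
    # then convert only the entries that are not already Variables, so
    # existing (e.g. scalar) Variables keep their type.
    if not isinstance(values, Sequence):
        values = [values]
    return [v if isinstance(v, Variable) else as_tensor_variable(v) for v in values]
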
tests/test_gradient.py

Lines changed: 12 additions & 0 deletions
@@ -11,6 +11,7 @@
     DisconnectedType,
     GradClip,
     GradScale,
+    Lop,
     NullTypeGradError,
     Rop,
     UndefinedGrad,

@@ -32,6 +33,7 @@
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import Op
 from pytensor.graph.traversal import graph_inputs
+from pytensor.scalar import float64
 from pytensor.scan.op import Scan
 from pytensor.tensor.math import add, dot, exp, outer, sigmoid, sqr, sqrt, tanh
 from pytensor.tensor.math import sum as pt_sum

@@ -1207,3 +1209,13 @@ def test_multiple_wrt(self):
         hessp_x_eval, hessp_y_eval = hessp_fn(**test)
         np.testing.assert_allclose(hessp_x_eval, [2, 4, 6])
         np.testing.assert_allclose(hessp_y_eval, [-6, -4, -2])
+
+
+def test_scalar_Lop():
+    xtm1 = float64("xtm1")
+    xt = xtm1**2
+
+    dout_dxt = float64("dout_dxt")
+    dout_dxtm1 = Lop(xt, wrt=xtm1, eval_points=dout_dxt)
+    assert dout_dxtm1.type == dout_dxt.type
+    assert dout_dxtm1.eval({xtm1: 3.0, dout_dxt: 1.5}) == 2 * 3.0 * 1.5
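
The new test pins down both halves of the change: dout_dxtm1 keeps the scalar float64 type of dout_dxt (the first assertion) instead of being promoted to a TensorVariable, and its value follows the chain rule for the L-operator, dout_dxt * d(xtm1**2)/d(xtm1) = 1.5 * (2 * 3.0) = 9.0, which is what the 2 * 3.0 * 1.5 in the final assertion evaluates to.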
