From a749d88738dea3be584c039076b3b4c2cacff8b0 Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Sat, 7 Jun 2025 22:21:37 +0800
Subject: [PATCH 1/2] More robust shape check for `grad` fallback in `jacobian`

---
 pytensor/gradient.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytensor/gradient.py b/pytensor/gradient.py
index 04572b29d0..99dae108eb 100644
--- a/pytensor/gradient.py
+++ b/pytensor/gradient.py
@@ -2069,13 +2069,13 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
     else:
         wrt = [wrt]
 
-    if expression.ndim == 0:
+    if all(expression.type.broadcastable):
         # expression is just a scalar, use grad
         return as_list_or_tuple(
             using_list,
             using_tuple,
             grad(
-                expression,
+                expression.squeeze(),
                 wrt,
                 consider_constant=consider_constant,
                 disconnected_inputs=disconnected_inputs,

From 818b641cf3a263700b478265660234790f2e3e5f Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Sat, 7 Jun 2025 22:27:52 +0800
Subject: [PATCH 2/2] Update scalar test

---
 tests/test_gradient.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/tests/test_gradient.py b/tests/test_gradient.py
index 24f5964c92..9673f8338e 100644
--- a/tests/test_gradient.py
+++ b/tests/test_gradient.py
@@ -30,6 +30,7 @@
 from pytensor.graph.basic import Apply, graph_inputs
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import Op
+from pytensor.scan.op import Scan
 from pytensor.tensor.math import add, dot, exp, sigmoid, sqr, tanh
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.random import RandomStream
@@ -1036,6 +1037,17 @@ def test_jacobian_scalar():
     vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     assert np.allclose(f(vx), 2)
 
+    # test when input is a shape (1,) vector -- should still be treated as a scalar
+    Jx = jacobian(y[None], x)
+    f = pytensor.function([x], Jx)
+
+    # Ensure we hit the scalar grad case (doesn't use scan)
+    nodes = f.maker.fgraph.apply_nodes
+    assert not any(isinstance(node.op, Scan) for node in nodes)
+
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
+    assert np.allclose(f(vx), 2)
+
     # test when the jacobian is called with a tuple as wrt
     Jx = jacobian(y, (x,))
     assert isinstance(Jx, tuple)
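
Note (not part of the patch): a minimal standalone sketch of the behaviour the new test exercises, assuming the patched `jacobian` and current PyTensor APIs -- an expression whose dims are all broadcastable (e.g. shape (1,)) should take the scalar `grad` fallback and compile without any `Scan` node.

    import numpy as np

    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import jacobian
    from pytensor.scan.op import Scan

    x = pt.scalar("x")
    y = 2 * x

    # Wrap the scalar result into a shape (1,) vector; with the patch this
    # still uses the grad fallback because all dims are broadcastable.
    Jx = jacobian(y[None], x)
    f = pytensor.function([x], Jx)

    # No Scan node should appear in the compiled graph.
    assert not any(isinstance(node.op, Scan) for node in f.maker.fgraph.apply_nodes)

    vx = np.asarray(0.5, dtype=pytensor.config.floatX)
    assert np.allclose(f(vx), 2)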