
Commit a749d88

More robust shape check for grad fallback in jacobian
1 parent ff98ab8 commit a749d88

1 file changed: +2 −2 lines


pytensor/gradient.py

Lines changed: 2 additions & 2 deletions
@@ -2069,13 +2069,13 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
     else:
         wrt = [wrt]
 
-    if expression.ndim == 0:
+    if all(expression.type.broadcastable):
         # expression is just a scalar, use grad
         return as_list_or_tuple(
             using_list,
             using_tuple,
             grad(
-                expression,
+                expression.squeeze(),
                 wrt,
                 consider_constant=consider_constant,
                 disconnected_inputs=disconnected_inputs,
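
A minimal sketch (not part of the commit) of why the broadcastable check is more robust than the old expression.ndim == 0 test: an expression whose dimensions are all broadcastable, for example one of static shape (1, 1), is scalar-like even though its ndim is 2, so the old check skipped the cheaper grad fallback. Because grad still requires a 0-d cost, the patched code squeezes the expression first. The variable names below are illustrative.

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # Scalar-like expression: a 0-d sum padded to shape (1, 1) with broadcastable axes
    cost = pt.sum(x**2).dimshuffle("x", "x")

    print(cost.ndim)                     # 2    -> the old ndim == 0 check would not trigger
    print(all(cost.type.broadcastable))  # True -> the new check treats it as scalar-like
    # grad expects a 0-d cost, hence the .squeeze() added in this commit
    g = pytensor.grad(cost.squeeze(), x)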
