Commit 3facb1c

Author: Jesse Grabowski

Dont eagerly rewrite graph in minimize/root helpers

1 parent 468fd98 · commit 3facb1c

2 files changed: +72 -23 lines


pytensor/tensor/optimize.py

Lines changed: 57 additions & 18 deletions
@@ -6,8 +6,7 @@
 
 import pytensor.scalar as ps
 from pytensor.compile.function import function
-from pytensor.gradient import grad, grad_not_implemented, jacobian
-from pytensor.graph import rewrite_graph
+from pytensor.gradient import DisconnectedType, grad, grad_not_implemented, jacobian
 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
@@ -255,16 +254,31 @@ def scalar_implict_optimization_grads(
     output_grad: TensorVariable,
     fgraph: FunctionGraph,
 ) -> list[TensorVariable | ScalarVariable]:
+    inner_args_to_diff = []
+    outer_args_to_diff = []
+    for inner_arg, outer_arg in zip(inner_args, args):
+        if inner_arg.type.dtype.startswith("float"):
+            inner_args_to_diff.append(inner_arg)
+            outer_args_to_diff.append(outer_arg)
+
+    if len(args) > 0 and not inner_args_to_diff:
+        # No differentiable arguments, return disconnected gradients
+        return [DisconnectedType()() for _ in args]
+
     df_dx, *df_dthetas = grad(
-        inner_fx, [inner_x, *inner_args], disconnected_inputs="ignore"
+        inner_fx, [inner_x, *inner_args_to_diff], disconnected_inputs="ignore"
     )
 
     replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
     df_dx_star, *df_dthetas_stars = graph_replace([df_dx, *df_dthetas], replace=replace)
 
+    arg_to_grad = dict(zip(outer_args_to_diff, df_dthetas_stars))
+
     grad_wrt_args = [
-        (-df_dtheta_star / df_dx_star) * output_grad
-        for df_dtheta_star in df_dthetas_stars
+        (-arg_to_grad[arg] / df_dx_star) * output_grad
+        if arg in arg_to_grad
+        else DisconnectedType()()
+        for arg in args
     ]
 
     return grad_wrt_args
@@ -317,10 +331,26 @@ def implict_optimization_grads(
     fgraph : FunctionGraph
         The function graph that contains the inputs and outputs of the optimization problem.
     """
+
+    # There might be non-differentiable arguments along the compute path from the objective to the inputs. Notably,
+    # integers often arise due to Shape ops called by pack/unpack. These will be given DisconnectedType gradients.
+    # First, they are filtered out before calling jacobian.
+    inner_args_to_diff = []
+    outer_args_to_diff = []
+    for inner_arg, outer_arg in zip(inner_args, args):
+        if inner_arg.type.dtype.startswith("float"):
+            inner_args_to_diff.append(inner_arg)
+            outer_args_to_diff.append(outer_arg)
+
+    if len(args) > 0 and not inner_args_to_diff:
+        # No differentiable arguments, return disconnected gradients
+        return [DisconnectedType()() for _ in args]
+
+    # Gradients are computed using the inner graph of the optimization op, not the actual inputs/outputs of the op.
     packed_inner_args, packed_arg_shapes, implicit_f = (
         _maybe_pack_input_variables_and_rewrite_objective(
             implicit_f,
-            inner_args,
+            inner_args_to_diff,
         )
     )
 
@@ -331,9 +361,11 @@ def implict_optimization_grads(
         vectorize=use_vectorized_jac,
     )
 
+    # Replace inner inputs (abstract dummies) with outer inputs (the actual user-provided symbols)
+    # at the solution point. From here on, the inner values should not be referenced.
     inner_to_outer_map = dict(zip(fgraph.inputs, (x_star, *args)))
-
     df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], inner_to_outer_map)
+
     grad_wrt_args_packed = solve(-atleast_2d(df_dx_star), atleast_1d(df_dtheta_star))
 
     if packed_arg_shapes is not None:
@@ -351,16 +383,23 @@ def implict_optimization_grads(
         grad_wrt_args_packed = grad_wrt_args_packed.squeeze(axis=0)
         grad_wrt_args = [grad_wrt_args_packed]
 
-    final_grads = [
-        tensordot(output_grad, arg_grad, [[0], [0]])
-        if arg_grad.ndim > 0 and output_grad.ndim > 0
-        else arg_grad * output_grad
-        for arg_grad in grad_wrt_args
-    ]
-    final_grads = [
-        scalar_from_tensor(g) if isinstance(arg.type, ScalarType) else g
-        for arg, g in zip(args, final_grads)
-    ]
+    arg_to_grad = dict(zip(outer_args_to_diff, grad_wrt_args))
+
+    final_grads = []
+    for arg in args:
+        arg_grad = arg_to_grad.get(arg, None)
+
+        if arg_grad is None:
+            final_grads.append(DisconnectedType()())
+            continue
+
+        if arg_grad.ndim > 0 and output_grad.ndim > 0:
+            g = tensordot(output_grad, arg_grad, [[0], [0]])
+        else:
+            g = arg_grad * output_grad
+        if isinstance(arg.type, ScalarType):
+            g = scalar_from_tensor(g)
+        final_grads.append(g)
 
     return final_grads
 
@@ -640,7 +679,7 @@ def _maybe_pack_input_variables_and_rewrite_objective(
             for xi, ui in zip(x, unpacked_output)
         },
     )
-    objective = rewrite_graph(objective, include=("ShapeOpt", "canonicalize"))
+
     return packed_input, packed_shapes, objective
 
 
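Both gradient helpers above are applications of the implicit function theorem: if f(x*, theta) = 0 (or x* minimizes f) defines x*(theta), then dx*/dtheta = -(df/dx)^(-1) * df/dtheta evaluated at the solution, which is what -arg_to_grad[arg] / df_dx_star and solve(-atleast_2d(df_dx_star), atleast_1d(df_dtheta_star)) build on the inner graph before graph_replace maps the result onto the outer inputs; arguments that are not floats never enter that computation and receive DisconnectedType gradients instead. A standalone sketch of the scalar identity (the objective x**3 - theta, the variable names, and the numbers below are illustrative only, not taken from this commit):

    import pytensor.tensor as pt
    from pytensor import function
    from pytensor.gradient import grad

    # Illustrative only: f(x, theta) = x**3 - theta has the root x*(theta) = theta**(1/3).
    x = pt.scalar("x")
    theta = pt.scalar("theta")
    f = x**3 - theta

    df_dx = grad(f, x)
    df_dtheta = grad(f, theta)
    dxstar_dtheta = -df_dtheta / df_dx  # implicit-function-theorem gradient, valid at x = x*(theta)

    fn = function([x, theta], dxstar_dtheta)
    theta_val = 8.0
    x_star = theta_val ** (1 / 3)  # 2.0
    print(fn(x_star, theta_val))  # ~0.0833, matches d/dtheta of theta**(1/3) at theta = 8
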
tests/tensor/test_optimize.py

Lines changed: 15 additions & 5 deletions
@@ -4,7 +4,11 @@
 import pytensor
 import pytensor.tensor as pt
 from pytensor import Variable, config, function
-from pytensor.gradient import NullTypeGradError, disconnected_type
+from pytensor.gradient import (
+    DisconnectedInputError,
+    NullTypeGradError,
+    disconnected_type,
+)
 from pytensor.graph import Apply, Op, Type
 from pytensor.tensor import alloc, scalar, scalar_from_tensor, tensor_from_scalar
 from pytensor.tensor.optimize import minimize, minimize_scalar, root, root_scalar
@@ -438,7 +442,11 @@ def test_optimize_grad_scalar_arg(optimize_op):
     np.testing.assert_allclose(grad_wrt_theta.eval({x: np.pi, theta: np.e}), -1)
 
 
-@pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar))
+@pytest.mark.parametrize(
+    "optimize_op",
+    (minimize, minimize_scalar, root, root_scalar),
+    ids=["minimize", "minimize_scalar", "root", "root_scalar"],
+)
 def test_optimize_grad_disconnected_numerical_inp(optimize_op):
     x = scalar("x", dtype="float64")
     theta = scalar("theta", dtype="int64")
@@ -449,12 +457,14 @@ def test_optimize_grad_disconnected_numerical_inp(optimize_op):
     assert x0.owner.inputs[1] is theta
 
     # This should technically raise, but does not right now
-    grad_wrt_theta = pt.grad(x0, theta, disconnected_inputs="raise")
-    np.testing.assert_allclose(grad_wrt_theta.eval({x: np.pi, theta: 5}), 0)
+    with pytest.raises(DisconnectedInputError):
+        pt.grad(x0, theta, disconnected_inputs="raise")
 
     # This should work even if the previous one raised
     grad_wrt_theta = pt.grad(x0, theta, disconnected_inputs="ignore")
-    np.testing.assert_allclose(grad_wrt_theta.eval({x: np.pi, theta: 5}), 0)
+    np.testing.assert_allclose(
+        grad_wrt_theta.eval({x: np.pi, theta: 5}, on_unused_input="ignore"), 0
+    )
 
 
 @pytest.mark.parametrize("optimize_op", (minimize, minimize_scalar, root, root_scalar))
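
The added on_unused_input="ignore" in the final assertion follows from the filtering above: once the gradient with respect to the integer theta is disconnected, the compiled gradient expression no longer references theta, and eval would otherwise reject the now-unused input. A minimal sketch of that behaviour, independent of the optimize ops (the variables a, b, and y are made up for illustration):

    import pytensor.tensor as pt

    a = pt.scalar("a", dtype="float64")
    b = pt.scalar("b", dtype="int64")
    y = a**2  # does not depend on b at all

    g = pt.grad(y, a)
    # Supplying a value for the unused `b` is rejected by default, but tolerated here:
    print(g.eval({a: 3.0, b: 5}, on_unused_input="ignore"))  # 6.0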
