Commit 92db737

Add RootOp, fix gradient tests (they are failing)

1 parent 7b53914 · commit 92db737
2 files changed: +223 -47 lines changed

pytensor/tensor/optimize.py (159 additions, 44 deletions)
@@ -1,55 +1,43 @@
+from collections.abc import Sequence
 from copy import copy
 
 from scipy.optimize import minimize as scipy_minimize
+from scipy.optimize import root as scipy_root
 
-from pytensor import function, graph_replace
-from pytensor.gradient import grad
+from pytensor import Variable, function, graph_replace
+from pytensor.gradient import grad, jacobian
 from pytensor.graph import Apply, Constant, FunctionGraph
 from pytensor.graph.basic import truncated_graph_inputs
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.scalar import bool as scalar_bool
+from pytensor.tensor.basic import atleast_2d
+from pytensor.tensor.slinalg import solve
+from pytensor.tensor.variable import TensorVariable
 
 
-class MinimizeOp(Op, HasInnerGraph):
-    def __init__(
-        self,
-        x,
-        *args,
-        output,
-        method="BFGS",
-        jac=True,
-        hess=False,
-        hessp=False,
-        options: dict | None = None,
-        debug: bool = False,
-    ):
-        self.fgraph = FunctionGraph([x, *args], [output])
+class ScipyWrapperOp(Op, HasInnerGraph):
+    """Shared logic for scipy optimization ops"""
 
-        if jac:
-            grad_wrt_x = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
-            self.fgraph.add_output(grad_wrt_x)
-
-        self.jac = jac
-        self.hess = hess
-        self.hessp = hessp
-
-        self.method = method
-        self.options = options if options is not None else {}
-        self.debug = debug
-        self._fn = None
-        self._fn_wrapped = None
+    __props__ = ("method", "debug")
 
     def build_fn(self):
+        """
+        This is overloaded because scipy converts scalar inputs to lists, changing the return type. The
+        wrapper function logic is there to handle this.
+        """
+        # TODO: Introduce rewrites to change MinimizeOp to MinimizeScalarOp and RootOp to RootScalarOp
+        # when x is scalar. That will remove the need for the wrapper.
+
         outputs = self.inner_outputs
         if len(outputs) == 1:
             outputs = outputs[0]
         self._fn = fn = function(self.inner_inputs, outputs)
-        self.fgraph = (
-            fn.maker.fgraph
-        )  # So we see the compiled graph ater the first call
+
+        # Do this reassignment to see the compiled graph in the dprint
+        self.fgraph = fn.maker.fgraph
 
         if self.inner_inputs[0].type.shape == ():
-            # Work-around for scipy changing the type of x
+
             def fn_wrapper(x, *args):
                 return fn(x.squeeze(), *args)
 
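An aside on the wrapper above: scipy's optimizers hand the objective callback a 1-d ndarray even when x0 is a Python scalar, so an inner function compiled for a 0-d input needs the squeeze. A minimal standalone illustration of that behavior (plain scipy, no pytensor; the names are illustrative only):

import numpy as np
from scipy.optimize import minimize

def objective(x):
    # Even for a scalar problem, x arrives as an ndarray of shape (1,)
    assert isinstance(x, np.ndarray) and x.shape == (1,)
    return (x.squeeze() - 3.0) ** 2

res = minimize(objective, x0=0.0)  # scalar x0 is promoted to shape (1,)
print(res.x)  # ~[3.]
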
@@ -90,21 +78,51 @@ def prepare_node(
         impl: str | None,
     ):
         """Trigger the compilation of the inner fgraph so it shows in the dprint before the first call"""
-        # TODO: Implemet this method
+        self.build_fn()
 
     def make_node(self, *inputs):
         assert len(inputs) == len(self.inner_inputs)
 
         return Apply(
-            self, inputs, [self.inner_outputs[0].type(), scalar_bool("success")]
+            self, inputs, [self.inner_inputs[0].type(), scalar_bool("success")]
         )
 
+
+class MinimizeOp(ScipyWrapperOp):
+    __props__ = ("method", "jac", "hess", "hessp", "debug")
+
+    def __init__(
+        self,
+        x,
+        *args,
+        objective,
+        method="BFGS",
+        jac=True,
+        hess=False,
+        hessp=False,
+        options: dict | None = None,
+        debug: bool = False,
+    ):
+        self.fgraph = FunctionGraph([x, *args], [objective])
+
+        if jac:
+            grad_wrt_x = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            self.fgraph.add_output(grad_wrt_x)
+
+        self.jac = jac
+        self.hess = hess
+        self.hessp = hessp
+
+        self.method = method
+        self.options = options if options is not None else {}
+        self.debug = debug
+        self._fn = None
+        self._fn_wrapped = None
+
     def perform(self, node, inputs, outputs):
         f = self.fn_wrapped
         x0, *args = inputs
 
-        # print(f(*inputs))
-
         res = scipy_minimize(
             fun=f,
             jac=self.jac,
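(Aside on the make_node fix in this hunk: the Op's first output holds res.x, the optimizing point, so its type must be that of the input x, hence self.inner_inputs[0].type(); the previous self.inner_outputs[0].type() was the scalar objective's type and would mismatch whenever x is a vector.)
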
@@ -113,8 +131,10 @@ def perform(self, node, inputs, outputs):
             method=self.method,
             **self.options,
         )
+
         if self.debug:
             print(res)
+
         outputs[0][0] = res.x
         outputs[1][0] = res.success
 
@@ -128,16 +148,12 @@ def L_op(self, inputs, outputs, output_grads):
 
         inner_grads = grad(inner_fx, [inner_x, *inner_args])
 
-        # TODO: Does clone replace do what we want? It might need a merge optimization pass afterwards
         replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
+
         grad_f_wrt_x_star, *grad_f_wrt_args = graph_replace(
             inner_grads, replace=replace
         )
 
-        # # TODO: If scipy optimizer uses hessian (or hessp), just store it from the inner function
-        # inner_hess = jacobian(inner_fx, inner_args)
-        # hess_f_x = clone_replace(inner_hess, replace=replace)
-
         grad_wrt_args = [
             -grad_f_wrt_arg / grad_f_wrt_x_star * output_grad
             for grad_f_wrt_arg in grad_f_wrt_args
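An aside on the gradient formed above: `-grad_f_wrt_arg / grad_f_wrt_x_star * output_grad` has the shape of implicit differentiation of a root condition $f(x^*, \theta) = 0$,

$$\frac{dx^*}{d\theta} = -\left.\frac{\partial f/\partial \theta}{\partial f/\partial x}\right|_{x = x^*},$$

but the condition defining a minimizer is stationarity, $\partial f/\partial x\,(x^*, \theta) = 0$, whose implicit derivative runs through second derivatives:

$$\frac{dx^*}{d\theta} = -\left(\frac{\partial^2 f}{\partial x^2}\right)^{-1} \left.\frac{\partial^2 f}{\partial x\,\partial \theta}\right|_{x = x^*}.$$

Since $\partial f/\partial x$ vanishes at $x^*$, dividing by it is ill-posed exactly where it is evaluated, which is consistent with the commit message reporting failing gradient tests.
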
@@ -192,9 +208,108 @@ def minimize(
     ]
 
     minimize_op = MinimizeOp(
-        x, *args, output=objective, method=method, jac=jac, debug=debug, options=options
+        x,
+        *args,
+        objective=objective,
+        method=method,
+        jac=jac,
+        debug=debug,
+        options=options,
     )
+
     return minimize_op(x, *args)
 
 
-__all__ = ["minimize"]
+class RootOp(ScipyWrapperOp):
+    __props__ = ("method", "jac", "debug")
+
+    def __init__(
+        self,
+        variables,
+        *args,
+        equations,
+        method="hybr",
+        jac=True,
+        options: dict | None = None,
+        debug: bool = False,
+    ):
+        self.fgraph = FunctionGraph([variables, *args], [equations])
+
+        if jac:
+            jac_wrt_x = jacobian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            self.fgraph.add_output(atleast_2d(jac_wrt_x))
+
+        self.jac = jac
+
+        self.method = method
+        self.options = options if options is not None else {}
+        self.debug = debug
+        self._fn = None
+        self._fn_wrapped = None
+
+    def perform(self, node, inputs, outputs):
+        f = self.fn_wrapped
+        variables, *args = inputs
+
+        res = scipy_root(
+            fun=f,
+            jac=self.jac,
+            x0=variables,
+            args=tuple(args),
+            method=self.method,
+            **self.options,
+        )
+
+        if self.debug:
+            print(res)
+
+        outputs[0][0] = res.x
+        outputs[1][0] = res.success
+
+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        # TODO: Broken
+        x, *args = inputs
+        x_star, success = outputs
+        output_grad, _ = output_grads
+
+        inner_x, *inner_args = self.fgraph.inputs
+        inner_fx = self.fgraph.outputs[0]
+
+        inner_jac = jacobian(inner_fx, [inner_x, *inner_args])
+
+        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
+        jac_f_wrt_x_star, *jac_f_wrt_args = graph_replace(inner_jac, replace=replace)
+
+        jac_wrt_args = solve(-jac_f_wrt_x_star, output_grad)
+
+        return [x.zeros_like(), jac_wrt_args]
+
+
+def root(
+    equations: TensorVariable,
+    variables: TensorVariable,
+    method: str = "hybr",
+    jac: bool = True,
+    debug: bool = False,
+):
+    """Find roots of a system of equations using scipy.optimize.root."""
+
+    args = [
+        arg
+        for arg in truncated_graph_inputs([equations], [variables])
+        if (arg is not variables and not isinstance(arg, Constant))
+    ]
+
+    root_op = RootOp(
+        variables, *args, equations=equations, method=method, jac=jac, debug=debug
+    )
+
+    return root_op(variables, *args)
+
+
+__all__ = ["minimize", "root"]
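An aside on RootOp.L_op, which the code flags as `# TODO: Broken` (matching the commit message): for the root condition $f(x^*(\theta), \theta) = 0$, the implicit function theorem gives $dx^*/d\theta = -J_x^{-1} J_\theta$, so the vector-Jacobian product with an output gradient $g$ should thread through jac_f_wrt_args, roughly `-jac_f_wrt_arg.T @ solve(jac_f_wrt_x_star.T, output_grad)` per argument; the body above computes those Jacobians but never uses jac_f_wrt_args. A self-contained numeric check of the scalar version of that formula, using the same equation as test_root_simple (plain numpy/scipy; variable names are illustrative only):

import numpy as np
from scipy.optimize import root

# f(x, a) = x + 2*a*cos(x); solve for x* at a = 1
a = 1.0
x_star = root(lambda x: x + 2 * a * np.cos(x), x0=0.0).x.item()

# Implicit function theorem: dx*/da = -(df/da) / (df/dx) at (x*, a)
df_da = 2 * np.cos(x_star)
df_dx = 1 - 2 * a * np.sin(x_star)
analytic = -df_da / df_dx

# Finite-difference comparison
eps = 1e-6
x_eps = root(lambda x: x + 2 * (a + eps) * np.cos(x), x0=x_star).x.item()
print(analytic, (x_eps - x_star) / eps)  # the two values should agree closely
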

tests/tensor/test_optimize.py (64 additions, 3 deletions)

@@ -1,8 +1,9 @@
 import numpy as np
 
+import pytensor
 import pytensor.tensor as pt
 from pytensor import config
-from pytensor.tensor.optimize import minimize
+from pytensor.tensor.optimize import minimize, root
 from tests import unittest_tools as utt
 
 
@@ -19,6 +20,7 @@ def test_simple_minimize():
     out = (x - b * c) ** 2
 
     minimized_x, success = minimize(out, x)
+    minimized_x.dprint()
 
     a_val = 2.0
     c_val = 3.0
@@ -43,7 +45,6 @@ def rosenbrock_shifted_scaled(x, a, b):
     b = pt.scalar("b")
 
     objective = rosenbrock_shifted_scaled(x, a, b)
-
     minimized_x, success = minimize(objective, x, method="BFGS")
 
     a_val = 0.5
@@ -56,4 +57,64 @@ def rosenbrock_shifted_scaled(x, a, b):
         x_star_val, np.ones_like(x_star_val), atol=1e-6, rtol=1e-6
     )
 
-    utt.verify_grad(rosenbrock_shifted_scaled, [x0, a_val, b_val], eps=1e-6)
+    def f(x, a, b):
+        objective = rosenbrock_shifted_scaled(x, a, b)
+        out = minimize(objective, x)[0]
+        return out
+
+    utt.verify_grad(f, [x0, a_val, b_val], eps=1e-6)
+
+
+def test_root_simple():
+    x = pt.scalar("x")
+    a = pt.scalar("a")
+
+    def fn(x, a):
+        return x + 2 * a * pt.cos(x)
+
+    f = fn(x, a)
+    root_f, success = root(f, x)
+    func = pytensor.function([x, a], [root_f, success])
+
+    x0 = 0.0
+    a_val = 1.0
+    solution, success = func(x0, a_val)
+
+    assert success
+    np.testing.assert_allclose(solution, -1.02986653, atol=1e-6, rtol=1e-6)
+
+    def root_fn(x, a):
+        f = fn(x, a)
+        return root(f, x)[0]
+
+    utt.verify_grad(root_fn, [x0, a_val], eps=1e-6)
+
+
+def test_root_system_of_equations():
+    x = pt.dvector("x")
+    a = pt.dvector("a")
+    b = pt.dvector("b")
+
+    f = pt.stack([a[0] * x[0] * pt.cos(x[1]) - b[0], x[0] * x[1] - a[1] * x[1] - b[1]])
+
+    root_f, success = root(f, x, debug=True)
+    func = pytensor.function([x, a, b], [root_f, success])
+
+    x0 = np.array([1.0, 1.0])
+    a_val = np.array([1.0, 1.0])
+    b_val = np.array([4.0, 5.0])
+    solution, success = func(x0, a_val, b_val)
+
+    assert success
+
+    np.testing.assert_allclose(
+        solution, np.array([6.50409711, 0.90841421]), atol=1e-6, rtol=1e-6
+    )
+
+    def root_fn(x, a, b):
+        f = pt.stack(
+            [a[0] * x[0] * pt.cos(x[1]) - b[0], x[0] * x[1] - a[1] * x[1] - b[1]]
+        )
+        return root(f, x)[0]
+
+    utt.verify_grad(root_fn, [x0, a_val, b_val], eps=1e-6)
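
As a sanity check, the expected root in test_root_system_of_equations does satisfy the system; a minimal numpy verification:

import numpy as np

x = np.array([6.50409711, 0.90841421])
a = np.array([1.0, 1.0])
b = np.array([4.0, 5.0])

# Residuals of the two equations at the expected solution; both ~0
residuals = np.array([
    a[0] * x[0] * np.cos(x[1]) - b[0],
    x[0] * x[1] - a[1] * x[1] - b[1],
])
print(residuals)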
