Commit 21c10a1

Implement optimize.minimize

1 parent 911c6a3 · commit 21c10a1

2 files changed: +185 -0 lines changed

pytensor/tensor/optimize.py

Lines changed: 159 additions & 0 deletions (new file)
from copy import copy

from scipy.optimize import minimize as scipy_minimize

from pytensor import function
from pytensor.gradient import grad
from pytensor.graph import Apply, Constant, FunctionGraph, clone_replace
from pytensor.graph.basic import truncated_graph_inputs
from pytensor.graph.op import HasInnerGraph, Op
from pytensor.scalar import bool as scalar_bool


class MinimizeOp(Op, HasInnerGraph):
    def __init__(
        self,
        x,
        *args,
        output,
        method="BFGS",
        jac=False,
        options: dict | None = None,
        debug: bool = False,
    ):
        self.fgraph = FunctionGraph([x, *args], [output])

        if jac:
            grad_wrt_x = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
            self.fgraph.add_output(grad_wrt_x)

        self.jac = jac
        # self.hess = hess
        self.method = method
        self.options = options if options is not None else {}
        self.debug = debug
        self._fn = None
        self._fn_wrapped = None

    def build_fn(self):
        outputs = self.inner_outputs
        if len(outputs) == 1:
            outputs = outputs[0]
        self._fn = fn = function(self.inner_inputs, outputs)
        # Keep the compiled fgraph so we see the compiled graph after the first call
        self.fgraph = fn.maker.fgraph

        if self.inner_inputs[0].type.shape == ():
            # Work-around for scipy passing x as a 1d array even when x0 is a scalar
            def fn_wrapper(x, *args):
                return fn(x.squeeze(), *args)

            self._fn_wrapped = fn_wrapper
        else:
            self._fn_wrapped = fn

    @property
    def fn(self):
        if self._fn is None:
            self.build_fn()
        return self._fn

    @property
    def fn_wrapped(self):
        if self._fn_wrapped is None:
            self.build_fn()
        return self._fn_wrapped

    @property
    def inner_inputs(self):
        return self.fgraph.inputs

    @property
    def inner_outputs(self):
        return self.fgraph.outputs

    def clone(self):
        copy_op = copy(self)
        copy_op.fgraph = self.fgraph.clone()
        return copy_op

    # def prepare_node():
    #     # ... trigger the compilation of the inner fgraph so it shows in the dprint before the first call
    #     ...

    def make_node(self, *inputs):
        assert len(inputs) == len(self.inner_inputs)
        # TODO: Assert the input types match the inner inputs.
        return Apply(
            self, inputs, [self.inner_outputs[0].type(), scalar_bool("success")]
        )

    def perform(self, node, inputs, outputs):
        f = self.fn_wrapped
        x0, *args = inputs

        res = scipy_minimize(
            fun=f,
            jac=self.jac,
            x0=x0,
            args=tuple(args),
            method=self.method,
            **self.options,
        )
        if self.debug:
            print(res)
        outputs[0][0] = res.x
        outputs[1][0] = res.success

    def L_op(self, inputs, outputs, output_grads):
        x, *args = inputs
        x_star, success = outputs
        output_grad, _ = output_grads

        # x_root, stats = root(func, x0, args=[arg], tol=1e-8)

        inner_x, *inner_args = self.fgraph.inputs
        inner_fx = self.fgraph.outputs[0]

        # f_x_star = clone_replace(inner_fx, replace={inner_x: x_star})

        inner_grads = grad(inner_fx, [inner_x, *inner_args])

        # TODO: Does clone_replace do what we want? It might need a merge optimization pass afterwards
        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
        grad_f_wrt_x_star, *grad_f_wrt_args = clone_replace(
            inner_grads, replace=replace
        )

        # TODO: If the scipy optimizer uses the hessian (or hessp), just store it from the inner function
        # inner_hess = jacobian(inner_fx, inner_args)
        # hess_f_x = clone_replace(inner_hess, replace=replace)

        grad_wrt_args = [
            -grad_f_wrt_arg / grad_f_wrt_x_star * output_grad
            for grad_f_wrt_arg in grad_f_wrt_args
        ]

        return [x.zeros_like(), *grad_wrt_args]


def minimize(
    objective, x, jac: bool = True, debug: bool = False, options: dict | None = None
):
    args = [
        arg
        for arg in truncated_graph_inputs([objective], [x])
        if (arg is not x and not isinstance(arg, Constant))
    ]
    minimize_op = MinimizeOp(
        x, *args, output=objective, jac=jac, debug=debug, options=options
    )
    return minimize_op(x, *args)


__all__ = ["minimize"]
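
A note on the gradient logic in L_op above: the gradients of the inner objective are re-evaluated symbolically at the optimum x_star and combined as -grad_f_wrt_arg / grad_f_wrt_x_star * output_grad. Written out, this appears to be the scalar implicit-function rule for a relation f(x*, θ) = 0 (the same rule hinted at by the commented-out root call), with θ standing for each extra argument:

$$
\frac{\partial x^{\star}}{\partial \theta}
  = -\left.\frac{\partial f/\partial \theta}{\partial f/\partial x}\right|_{x = x^{\star}}
$$

The gradient with respect to the initial point x is returned as x.zeros_like(), reflecting that (for a successful optimization) the result does not depend on where the search starts.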

tests/tensor/test_optimize.py

Lines changed: 26 additions & 0 deletions (new file)
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize


def test_minimize():
    x = pt.scalar("x")
    a = pt.scalar("a")
    c = pt.scalar("c")

    b = a * 2
    b.name = "b"
    out = (x - b * c) ** 2

    minimized_x, success = minimize(out, x, debug=False)

    a_val = 2
    c_val = 3

    assert success
    assert minimized_x.eval({a: a_val, c: c_val, x: 0.0}) == (2 * a_val * c_val)

    x_grad, a_grad, c_grad = pt.grad(minimized_x, [x, a, c])

    assert x_grad.eval({x: 0.0}) == 0.0
    assert a_grad.eval({a: a_val, c: c_val, x: 0.0}) == 2 * c_val
    assert c_grad.eval({a: a_val, c: c_val, x: 0.0}) == 2 * a_val
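
The expectations encoded in this test can also be checked against scipy directly. The snippet below is an illustrative, self-contained sketch (the finite-difference step eps and the tolerances are arbitrary choices, not values used anywhere in the commit): for the quadratic objective (x - 2ac)^2 the minimizer is x* = 2ac, and a finite-difference probe of x* with respect to a recovers the 2c gradient asserted above.

import numpy as np
from scipy.optimize import minimize as scipy_minimize

a_val, c_val, eps = 2.0, 3.0, 1e-2


def objective(x, a, c):
    # Same objective as the test: (x - b * c) ** 2 with b = 2 * a
    x = np.squeeze(x)  # scipy passes x as a 1d array
    return (x - (2 * a) * c) ** 2


# Minimizer matches the closed form x* = 2 * a * c
x_star = scipy_minimize(objective, x0=0.0, args=(a_val, c_val)).x.item()
assert np.isclose(x_star, 2 * a_val * c_val)

# Finite-difference probe of dx*/da recovers the 2 * c gradient the test asserts
x_star_eps = scipy_minimize(objective, x0=0.0, args=(a_val + eps, c_val)).x.item()
assert np.isclose((x_star_eps - x_star) / eps, 2 * c_val, atol=1e-2)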
