
Commit 7b53914: Code cleanup

1 parent 08b8c3c, commit 7b53914


pytensor/tensor/optimize.py

Lines changed: 59 additions & 18 deletions
@@ -2,11 +2,11 @@
 
 from scipy.optimize import minimize as scipy_minimize
 
-from pytensor import function
+from pytensor import function, graph_replace
 from pytensor.gradient import grad
-from pytensor.graph import Apply, Constant, FunctionGraph, clone_replace
+from pytensor.graph import Apply, Constant, FunctionGraph
 from pytensor.graph.basic import truncated_graph_inputs
-from pytensor.graph.op import HasInnerGraph, Op
+from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.scalar import bool as scalar_bool
 
 

@@ -17,7 +17,9 @@ def __init__(
         *args,
         output,
         method="BFGS",
-        jac=False,
+        jac=True,
+        hess=False,
+        hessp=False,
         options: dict | None = None,
         debug: bool = False,
     ):
@@ -28,7 +30,9 @@ def __init__(
         self.fgraph.add_output(grad_wrt_x)
 
         self.jac = jac
-        # self.hess = hess
+        self.hess = hess
+        self.hessp = hessp
+
         self.method = method
         self.options = options if options is not None else {}
         self.debug = debug
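
Note: the new jac, hess, and hessp flags line up with the derivative-related arguments of scipy.optimize.minimize, where callables supply gradient and Hessian information. A standalone scipy sketch of the pattern the op presumably wires up (the quadratic objective here is illustrative only, not from this commit):

import numpy as np
from scipy.optimize import minimize as scipy_minimize

def f(x):
    return np.sum((x - 2.0) ** 2)

def jac(x):
    # Gradient of f; passed when jac is enabled
    return 2.0 * (x - 2.0)

def hess(x):
    # Hessian of f; used by second-order methods such as Newton-CG
    return 2.0 * np.eye(x.shape[0])

res = scipy_minimize(f, x0=np.zeros(3), method="Newton-CG", jac=jac, hess=hess)
print(res.x)  # approximately [2. 2. 2.]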
@@ -78,14 +82,19 @@ def clone(self):
         copy_op.fgraph = self.fgraph.clone()
         return copy_op
 
-    # def prepare_node():
-    #     # ... trigger the compilation of the inner fgraph so it shows in the dprint before the first call
-    #     ...
+    def prepare_node(
+        self,
+        node: Apply,
+        storage_map: StorageMapType | None,
+        compute_map: ComputeMapType | None,
+        impl: str | None,
+    ):
+        """Trigger the compilation of the inner fgraph so it shows in the dprint before the first call"""
+        # TODO: Implement this method
 
     def make_node(self, *inputs):
-        # print(inputs)
         assert len(inputs) == len(self.inner_inputs)
-        # Assert type is correct.
+
         return Apply(
             self, inputs, [self.inner_outputs[0].type(), scalar_bool("success")]
         )
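
Note: prepare_node is still a stub in this commit. A plausible sketch of the eventual implementation, assuming the compiled inner function is cached on an attribute (the _compiled_fn name is hypothetical, not from this commit):

def prepare_node(self, node, storage_map, compute_map, impl):
    """Compile the inner fgraph eagerly so it shows up in dprint before the first call."""
    # Hypothetical cache attribute; compile once and reuse on later calls
    if getattr(self, "_compiled_fn", None) is None:
        self._compiled_fn = function(self.fgraph.inputs, self.fgraph.outputs)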
@@ -114,18 +123,14 @@ def L_op(self, inputs, outputs, output_grads):
         x_star, success = outputs
         output_grad, _ = output_grads
 
-        # x_root, stats = root(func, x0, args=[arg], tol=1e-8)
-
         inner_x, *inner_args = self.fgraph.inputs
         inner_fx = self.fgraph.outputs[0]
 
-        # f_x_star = clone_replace(inner_fx, replace={inner_x: x_star})
-
         inner_grads = grad(inner_fx, [inner_x, *inner_args])
 
         # TODO: Does clone replace do what we want? It might need a merge optimization pass afterwards
         replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
-        grad_f_wrt_x_star, *grad_f_wrt_args = clone_replace(
+        grad_f_wrt_x_star, *grad_f_wrt_args = graph_replace(
             inner_grads, replace=replace
         )
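Note on the clone_replace to graph_replace swap: graph_replace rebuilds the requested outputs with the given substitutions applied, which is what L_op needs in order to evaluate the inner gradients at the optimum x_star. A minimal, self-contained illustration of the substitution semantics (variable names are illustrative, not from this commit):

import pytensor.tensor as pt
from pytensor import graph_replace
from pytensor.gradient import grad

x = pt.dscalar("x")
a = pt.dscalar("a")
fx = (x - a) ** 2

g = grad(fx, x)  # symbolic df/dx = 2 * (x - a)
x_star = pt.dscalar("x_star")
g_at_star = graph_replace(g, replace={x: x_star})  # df/dx with x replaced by x_star
print(g_at_star.eval({x_star: 5.0, a: 2.0}))  # 6.0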
@@ -142,16 +147,52 @@
 
 
 def minimize(
-    objective, x, jac: bool = True, debug: bool = False, options: dict | None = None
+    objective,
+    x,
+    method: str = "BFGS",
+    jac: bool = True,
+    debug: bool = False,
+    options: dict | None = None,
 ):
+    """
+    Minimize a scalar objective function using scipy.optimize.minimize.
+
+    Parameters
+    ----------
+    objective : TensorVariable
+        The objective function to minimize. This should be a pytensor variable representing a scalar value.
+
+    x : TensorVariable
+        The variable with respect to which the objective function is minimized. It must be an input to the
+        computational graph of `objective`.
+
+    method : str, optional
+        The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
+
+    jac : bool, optional
+        Whether to compute and use the gradient of the objective function with respect to x for optimization.
+        Default is True.
+
+    debug : bool, optional
+        If True, prints the raw scipy result after optimization. Default is False.
+
+    options : dict, optional
+        Dictionary of additional keyword arguments to pass to scipy.optimize.minimize.
+
+    Returns
+    -------
+    TensorVariable
+        The optimized value of x that minimizes the objective function.
+
+    """
     args = [
         arg
         for arg in truncated_graph_inputs([objective], [x])
        if (arg is not x and not isinstance(arg, Constant))
     ]
-    # print(args)
+
     minimize_op = MinimizeOp(
-        x, *args, output=objective, jac=jac, debug=debug, options=options
+        x, *args, output=objective, method=method, jac=jac, debug=debug, options=options
     )
     return minimize_op(x, *args)
 
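Note: with the expanded signature, a short usage sketch of the new entry point (names and values are illustrative; per make_node above, calling the op appears to yield both the minimizer and a boolean success flag):

import pytensor.tensor as pt
from pytensor import function
from pytensor.tensor.optimize import minimize

a = pt.dscalar("a")
x = pt.dscalar("x")
objective = (x - a) ** 2  # minimized at x = a

x_star, success = minimize(objective, x, method="BFGS")
fn = function([x, a], [x_star, success])
print(fn(0.0, 3.0))  # expected: x_star close to 3.0, success True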