@@ -2,11 +2,11 @@
 
 from scipy.optimize import minimize as scipy_minimize
 
-from pytensor import function
+from pytensor import function, graph_replace
 from pytensor.gradient import grad
-from pytensor.graph import Apply, Constant, FunctionGraph, clone_replace
+from pytensor.graph import Apply, Constant, FunctionGraph
 from pytensor.graph.basic import truncated_graph_inputs
-from pytensor.graph.op import HasInnerGraph, Op
+from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.scalar import bool as scalar_bool
 
 
@@ -17,7 +17,9 @@ def __init__(
         *args,
         output,
         method="BFGS",
-        jac=False,
+        jac=True,
+        hess=False,
+        hessp=False,
         options: dict | None = None,
         debug: bool = False,
     ):
@@ -28,7 +30,9 @@ def __init__(
         self.fgraph.add_output(grad_wrt_x)
 
         self.jac = jac
-        # self.hess = hess
+        self.hess = hess
+        self.hessp = hessp
+
         self.method = method
         self.options = options if options is not None else {}
         self.debug = debug
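At the scipy level, the three flags stored here mean: `jac=True` tells `scipy.optimize.minimize` that the objective callable returns `(value, gradient)` as a pair, while `hess` and `hessp` expect callables, the latter taking `(x, p)` and returning the Hessian-vector product. A standalone illustration with a made-up quadratic objective (not code from this PR):

    import numpy as np
    from scipy.optimize import minimize as scipy_minimize

    def fun(x):
        # Objective and gradient returned together, matching jac=True
        return np.sum((x - 2.0) ** 2), 2.0 * (x - 2.0)

    def hessp(x, p):
        # The Hessian of this objective is 2*I, so H @ p = 2*p
        return 2.0 * p

    res = scipy_minimize(fun, x0=np.zeros(3), method="Newton-CG", jac=True, hessp=hessp)
    print(res.x, res.success)  # approximately [2. 2. 2.], True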
@@ -78,14 +82,19 @@ def clone(self):
         copy_op.fgraph = self.fgraph.clone()
         return copy_op
 
-    # def prepare_node():
-    #     # ... trigger the compilation of the inner fgraph so it shows in the dprint before the first call
-    #     ...
+    def prepare_node(
+        self,
+        node: Apply,
+        storage_map: StorageMapType | None,
+        compute_map: ComputeMapType | None,
+        impl: str | None,
+    ):
+        """Trigger the compilation of the inner fgraph so it shows in the dprint before the first call."""
+        # TODO: Implement this method (a possible sketch follows this hunk)
 
     def make_node(self, *inputs):
-        # print(inputs)
         assert len(inputs) == len(self.inner_inputs)
-        # Assert type is correct.
+
         return Apply(
             self, inputs, [self.inner_outputs[0].type(), scalar_bool("success")]
         )
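A possible body for the `prepare_node` TODO above, as a sketch only: it assumes that eagerly compiling the inner graph is all that is needed, and `_fn` is a hypothetical cache attribute that this PR does not define:

    def prepare_node(self, node, storage_map, compute_map, impl):
        # Hypothetical: compile and cache the inner function once, so that
        # dprint can display the inner graph before the first call to perform
        if getattr(self, "_fn", None) is None:
            self._fn = function(list(self.fgraph.inputs), self.fgraph.outputs)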
@@ -114,18 +123,14 @@ def L_op(self, inputs, outputs, output_grads):
         x_star, success = outputs
         output_grad, _ = output_grads
 
-        # x_root, stats = root(func, x0, args=[arg], tol=1e-8)
-
         inner_x, *inner_args = self.fgraph.inputs
         inner_fx = self.fgraph.outputs[0]
 
-        # f_x_star = clone_replace(inner_fx, replace={inner_x: x_star})
-
         inner_grads = grad(inner_fx, [inner_x, *inner_args])
 
         # TODO: Does clone replace do what we want? It might need a merge optimization pass afterwards
         replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
-        grad_f_wrt_x_star, *grad_f_wrt_args = clone_replace(
+        grad_f_wrt_x_star, *grad_f_wrt_args = graph_replace(
             inner_grads, replace=replace
         )
 
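For context on what this hunk is assembling (the tail of `L_op` falls outside the diff): at the optimum, stationarity gives \(\nabla_x f(x^*(\theta), \theta) = 0\), and differentiating that condition with the implicit function theorem yields

\[
\frac{\partial x^*}{\partial \theta}
  = -\left[\nabla^2_{xx} f(x^*, \theta)\right]^{-1} \nabla^2_{x\theta} f(x^*, \theta).
\]

The first derivatives `grad_f_wrt_x_star` and `grad_f_wrt_args` computed above are the ingredients from which those second derivatives can be built, presumably contracted with `output_grad` further down.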
@@ -142,16 +147,52 @@ def L_op(self, inputs, outputs, output_grads):
 
 
 def minimize(
-    objective, x, jac: bool = True, debug: bool = False, options: dict | None = None
+    objective,
+    x,
+    method: str = "BFGS",
+    jac: bool = True,
+    debug: bool = False,
+    options: dict | None = None,
 ):
+    """
+    Minimize a scalar objective function using scipy.optimize.minimize.
+
+    Parameters
+    ----------
+    objective : TensorVariable
+        The objective function to minimize. This should be a PyTensor variable representing a scalar value.
+
+    x : TensorVariable
+        The variable with respect to which the objective function is minimized. It must be an input to the
+        computational graph of `objective`.
+
+    method : str, optional
+        The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
+
+    jac : bool, optional
+        Whether to compute and use the gradient of the objective function with respect to x during
+        optimization. Default is True.
+
+    debug : bool, optional
+        If True, prints the raw scipy result after optimization. Default is False.
+
+    options : dict, optional
+        Additional options to pass to scipy.optimize.minimize.
+
+    Returns
+    -------
+    TensorVariable
+        The optimized value of x that minimizes the objective function.
+
+    """
     args = [
         arg
         for arg in truncated_graph_inputs([objective], [x])
         if (arg is not x and not isinstance(arg, Constant))
     ]
-    # print(args)
+
     minimize_op = MinimizeOp(
-        x, *args, output=objective, jac=jac, debug=debug, options=options
+        x, *args, output=objective, method=method, jac=jac, debug=debug, options=options
     )
     return minimize_op(x, *args)
 
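A minimal usage sketch of the new signature (hypothetical example, assuming the Op's two outputs, the solution and the success flag, unpack as suggested by `make_node` above):

    import pytensor.tensor as pt
    from pytensor import function

    x = pt.scalar("x")  # also serves as the initial value for the optimizer
    a = pt.scalar("a")
    objective = (x - a) ** 2  # minimized at x* = a

    x_star, success = minimize(objective, x, method="BFGS", jac=True)
    fn = function([x, a], [x_star, success])
    print(fn(0.0, 3.0))  # expected: solution near 3.0, success True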