+from collections.abc import Sequence
 from copy import copy
 
 from scipy.optimize import minimize as scipy_minimize
+from scipy.optimize import root as scipy_root
 
-from pytensor import function, graph_replace
-from pytensor.gradient import grad
+from pytensor import Variable, function, graph_replace
+from pytensor.gradient import grad, jacobian
 from pytensor.graph import Apply, Constant, FunctionGraph
 from pytensor.graph.basic import truncated_graph_inputs
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.scalar import bool as scalar_bool
+from pytensor.tensor.basic import atleast_2d
+from pytensor.tensor.slinalg import solve
+from pytensor.tensor.variable import TensorVariable
 
 
-class MinimizeOp(Op, HasInnerGraph):
-    def __init__(
-        self,
-        x,
-        *args,
-        output,
-        method="BFGS",
-        jac=True,
-        hess=False,
-        hessp=False,
-        options: dict | None = None,
-        debug: bool = False,
-    ):
-        self.fgraph = FunctionGraph([x, *args], [output])
+class ScipyWrapperOp(Op, HasInnerGraph):
+    """Shared logic for scipy optimization ops"""
 
-        if jac:
-            grad_wrt_x = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
-            self.fgraph.add_output(grad_wrt_x)
-
-        self.jac = jac
-        self.hess = hess
-        self.hessp = hessp
-
-        self.method = method
-        self.options = options if options is not None else {}
-        self.debug = debug
-        self._fn = None
-        self._fn_wrapped = None
+    __props__ = ("method", "debug")
 
     def build_fn(self):
24+ """
25+ This is overloaded because scipy converts scalar inputs to lists, changing the return type. The
26+ wrapper function logic is there to handle this.
27+ """
28+ # TODO: Introduce rewrites to change MinimizeOp to MinimizeScalarOp and RootOp to RootScalarOp
29+ # when x is scalar. That will remove the need for the wrapper.
30+
         outputs = self.inner_outputs
         if len(outputs) == 1:
             outputs = outputs[0]
         self._fn = fn = function(self.inner_inputs, outputs)
-        self.fgraph = (
-            fn.maker.fgraph
-        )  # So we see the compiled graph after the first call
+
+        # Do this reassignment to see the compiled graph in the dprint
+        self.fgraph = fn.maker.fgraph
 
         if self.inner_inputs[0].type.shape == ():
-            # Work-around for scipy changing the type of x
+
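+            # scipy coerces x0 to at least 1d, so a scalar x arrives as a
+            # 1-element array; squeeze it back before calling the compiled
+            # function.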
             def fn_wrapper(x, *args):
                 return fn(x.squeeze(), *args)
 
@@ -90,21 +78,51 @@ def prepare_node(
         impl: str | None,
     ):
         """Trigger the compilation of the inner fgraph so it shows in the dprint before the first call"""
-        # TODO: Implement this method
+        self.build_fn()
 
     def make_node(self, *inputs):
         assert len(inputs) == len(self.inner_inputs)
 
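+        # The first output is x*, which has the same type as the input x;
+        # the second is scipy's boolean success flag.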
         return Apply(
-            self, inputs, [self.inner_outputs[0].type(), scalar_bool("success")]
+            self, inputs, [self.inner_inputs[0].type(), scalar_bool("success")]
         )
 
+
+class MinimizeOp(ScipyWrapperOp):
+    __props__ = ("method", "jac", "hess", "hessp", "debug")
+
+    def __init__(
+        self,
+        x,
+        *args,
+        objective,
+        method="BFGS",
+        jac=True,
+        hess=False,
+        hessp=False,
+        options: dict | None = None,
+        debug: bool = False,
+    ):
+        self.fgraph = FunctionGraph([x, *args], [objective])
+
+        if jac:
+            grad_wrt_x = grad(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            self.fgraph.add_output(grad_wrt_x)
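+            # With jac=True scipy expects the callable to return
+            # (objective, gradient), which matches the two inner-graph outputs.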
+
+        self.jac = jac
+        self.hess = hess
+        self.hessp = hessp
+
+        self.method = method
+        self.options = options if options is not None else {}
+        self.debug = debug
+        self._fn = None
+        self._fn_wrapped = None
+
     def perform(self, node, inputs, outputs):
         f = self.fn_wrapped
         x0, *args = inputs
 
-        # print(f(*inputs))
-
         res = scipy_minimize(
             fun=f,
             jac=self.jac,
@@ -113,8 +131,10 @@ def perform(self, node, inputs, outputs):
             method=self.method,
             **self.options,
         )
+
         if self.debug:
             print(res)
+
         outputs[0][0] = res.x
         outputs[1][0] = res.success
 
@@ -128,16 +148,12 @@ def L_op(self, inputs, outputs, output_grads):
 
         inner_grads = grad(inner_fx, [inner_x, *inner_args])
 
-        # TODO: Does clone replace do what we want? It might need a merge optimization pass afterwards
         replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
+
         grad_f_wrt_x_star, *grad_f_wrt_args = graph_replace(
             inner_grads, replace=replace
         )
 
-        # # TODO: If scipy optimizer uses hessian (or hessp), just store it from the inner function
-        # inner_hess = jacobian(inner_fx, inner_args)
-        # hess_f_x = clone_replace(inner_hess, replace=replace)
-
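+        # Implicit differentiation: treating f(x*(arg), arg) = 0, each
+        # dx*/darg is -grad_f_wrt_arg / grad_f_wrt_x_star, chained with the
+        # incoming output gradient.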
         grad_wrt_args = [
             -grad_f_wrt_arg / grad_f_wrt_x_star * output_grad
             for grad_f_wrt_arg in grad_f_wrt_args
@@ -192,9 +208,108 @@ def minimize(
     ]
 
     minimize_op = MinimizeOp(
-        x, *args, output=objective, method=method, jac=jac, debug=debug, options=options
+        x,
+        *args,
+        objective=objective,
+        method=method,
+        jac=jac,
+        debug=debug,
+        options=options,
     )
+
     return minimize_op(x, *args)
 
 
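+# A minimal usage sketch of `minimize` (illustrative only: the names `x` and
+# `a` are assumed, as is the `minimize(objective, x, ...)` argument order of
+# the enclosing definition):
+#
+#     import pytensor.tensor as pt
+#     x = pt.scalar("x")
+#     a = pt.scalar("a")
+#     x_star, success = minimize((x - a) ** 2, x)
+#
+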
-__all__ = ["minimize"]
+class RootOp(ScipyWrapperOp):
+    __props__ = ("method", "jac", "debug")
+
+    def __init__(
+        self,
+        variables,
+        *args,
+        equations,
+        method="hybr",
+        jac=True,
+        options: dict | None = None,
+        debug: bool = False,
+    ):
+        self.fgraph = FunctionGraph([variables, *args], [equations])
+
+        if jac:
+            jac_wrt_x = jacobian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+            self.fgraph.add_output(atleast_2d(jac_wrt_x))
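+            # scipy.optimize.root expects a 2d Jacobian even for a single
+            # scalar equation, hence atleast_2d.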
+
+        self.jac = jac
+
+        self.method = method
+        self.options = options if options is not None else {}
+        self.debug = debug
+        self._fn = None
+        self._fn_wrapped = None
+
+    def perform(self, node, inputs, outputs):
+        f = self.fn_wrapped
+        variables, *args = inputs
+
+        res = scipy_root(
+            fun=f,
+            jac=self.jac,
+            x0=variables,
+            args=tuple(args),
+            method=self.method,
+            **self.options,
+        )
+
+        if self.debug:
+            print(res)
+
+        outputs[0][0] = res.x
+        outputs[1][0] = res.success
+
+    def L_op(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        output_grads: Sequence[Variable],
+    ) -> list[Variable]:
+        # TODO: Broken
+        x, *args = inputs
+        x_star, success = outputs
+        output_grad, _ = output_grads
+
+        inner_x, *inner_args = self.fgraph.inputs
+        inner_fx = self.fgraph.outputs[0]
+
+        inner_jac = jacobian(inner_fx, [inner_x, *inner_args])
+
+        replace = dict(zip(self.fgraph.inputs, (x_star, *args), strict=True))
+        jac_f_wrt_x_star, *jac_f_wrt_args = graph_replace(inner_jac, replace=replace)
+
+        jac_wrt_args = solve(-jac_f_wrt_x_star, output_grad)
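+        # Implicit function theorem: with f(x*, args) = 0, dx*/dargs =
+        # -J_x^{-1} J_args; only the solve against the incoming gradient is
+        # formed here (see the TODO above).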
+
+        return [x.zeros_like(), jac_wrt_args]
+
+
+def root(
+    equations: TensorVariable,
+    variables: TensorVariable,
+    method: str = "hybr",
+    jac: bool = True,
+    debug: bool = False,
+):
+    """Find roots of a system of equations using scipy.optimize.root."""
+
+    args = [
+        arg
+        for arg in truncated_graph_inputs([equations], [variables])
+        if (arg is not variables and not isinstance(arg, Constant))
+    ]
+
+    root_op = RootOp(
+        variables, *args, equations=equations, method=method, jac=jac, debug=debug
+    )
+
+    return root_op(variables, *args)
+
+
+__all__ = ["minimize", "root"]