11from collections .abc import Sequence
22from copy import copy
3+ from typing import cast
34
45from scipy .optimize import minimize as scipy_minimize
56from scipy .optimize import root as scipy_root
1011from pytensor .graph .basic import truncated_graph_inputs
1112from pytensor .graph .op import ComputeMapType , HasInnerGraph , Op , StorageMapType
1213from pytensor .scalar import bool as scalar_bool
13- from pytensor .tensor .basic import atleast_2d , concatenate
14+ from pytensor .tensor .basic import atleast_2d , concatenate , zeros_like
1415from pytensor .tensor .slinalg import solve
1516from pytensor .tensor .variable import TensorVariable
1617
1718
1819class ScipyWrapperOp (Op , HasInnerGraph ):
1920 """Shared logic for scipy optimization ops"""
2021
21- __props__ = ("method" , "debug" )
22-
2322 def build_fn (self ):
2423 """
2524 This is overloaded because scipy converts scalar inputs to lists, changing the return type. The
@@ -93,28 +92,30 @@ class MinimizeOp(ScipyWrapperOp):
9392
9493 def __init__ (
9594 self ,
96- x ,
97- * args ,
98- objective ,
99- method = "BFGS" ,
100- jac = True ,
101- hess = False ,
102- hessp = False ,
103- options : dict | None = None ,
95+ x : Variable ,
96+ * args : Variable ,
97+ objective : Variable ,
98+ method : str = "BFGS" ,
99+ jac : bool = True ,
100+ hess : bool = False ,
101+ hessp : bool = False ,
102+ optimizer_kwargs : dict | None = None ,
104103 debug : bool = False ,
105104 ):
106105 self .fgraph = FunctionGraph ([x , * args ], [objective ])
107106
108107 if jac :
109- grad_wrt_x = grad (self .fgraph .outputs [0 ], self .fgraph .inputs [0 ])
108+ grad_wrt_x = cast (
109+ Variable , grad (self .fgraph .outputs [0 ], self .fgraph .inputs [0 ])
110+ )
110111 self .fgraph .add_output (grad_wrt_x )
111112
112113 self .jac = jac
113114 self .hess = hess
114115 self .hessp = hessp
115116
116117 self .method = method
117- self .options = options if options is not None else {}
118+ self .optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
118119 self .debug = debug
119120 self ._fn = None
120121 self ._fn_wrapped = None
@@ -132,9 +133,6 @@ def perform(self, node, inputs, outputs):
132133 ** self .options ,
133134 )
134135
135- if self .debug :
136- print (res )
137-
138136 outputs [0 ][0 ] = res .x
139137 outputs [1 ][0 ] = res .success
140138
@@ -185,12 +183,12 @@ def L_op(self, inputs, outputs, output_grads):
185183
186184
187185def minimize (
188- objective ,
189- x ,
186+ objective : TensorVariable ,
187+ x : TensorVariable ,
190188 method : str = "BFGS" ,
191189 jac : bool = True ,
192190 debug : bool = False ,
193- options : dict | None = None ,
191+ optimizer_kwargs : dict | None = None ,
194192):
195193 """
196194 Minimize a scalar objective function using scipy.optimize.minimize.
@@ -214,7 +212,7 @@ def minimize(
214212 debug : bool, optional
215213 If True, prints raw scipy result after optimization. Default is False.
216214
217- ** optimizer_kwargs
215+ optimizer_kwargs
218216 Additional keyword arguments to pass to scipy.optimize.minimize
219217
220218 Returns
@@ -236,7 +234,7 @@ def minimize(
236234 method = method ,
237235 jac = jac ,
238236 debug = debug ,
239- options = options ,
237+ optimizer_kwargs = optimizer_kwargs ,
240238 )
241239
242240 return minimize_op (x , * args )
@@ -247,12 +245,12 @@ class RootOp(ScipyWrapperOp):
247245
248246 def __init__ (
249247 self ,
250- variables ,
251- * args ,
252- equations ,
253- method = "hybr" ,
254- jac = True ,
255- options : dict | None = None ,
248+ variables : Variable ,
249+ * args : Variable ,
250+ equations : Variable ,
251+ method : str = "hybr" ,
252+ jac : bool = True ,
253+ optimizer_kwargs : dict | None = None ,
256254 debug : bool = False ,
257255 ):
258256 self .fgraph = FunctionGraph ([variables , * args ], [equations ])
@@ -264,7 +262,7 @@ def __init__(
264262 self .jac = jac
265263
266264 self .method = method
267- self .options = options if options is not None else {}
265+ self .optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
268266 self .debug = debug
269267 self ._fn = None
270268 self ._fn_wrapped = None
@@ -279,12 +277,9 @@ def perform(self, node, inputs, outputs):
279277 x0 = variables ,
280278 args = tuple (args ),
281279 method = self .method ,
282- ** self .options ,
280+ ** self .optimizer_kwargs ,
283281 )
284282
285- if self .debug :
286- print (res )
287-
288283 outputs [0 ][0 ] = res .x
289284 outputs [1 ][0 ] = res .success
290285
@@ -309,7 +304,7 @@ def L_op(
309304
310305 jac_wrt_args = solve (- jac_f_wrt_x_star , output_grad )
311306
312- return [x . zeros_like (), jac_wrt_args ]
307+ return [zeros_like (x ), jac_wrt_args ]
313308
314309
315310def root (
0 commit comments