 from copy import copy
 from typing import cast
 
+import numpy as np
 from scipy.optimize import minimize as scipy_minimize
 from scipy.optimize import root as scipy_root
 
 from pytensor import Variable, function, graph_replace
-from pytensor.gradient import DisconnectedType, grad, jacobian
+from pytensor.gradient import grad, jacobian
 from pytensor.graph import Apply, Constant, FunctionGraph
-from pytensor.graph.basic import truncated_graph_inputs
+from pytensor.graph.basic import graph_inputs, truncated_graph_inputs
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.scalar import bool as scalar_bool
+from pytensor.tensor import dot
 from pytensor.tensor.basic import atleast_2d, concatenate, zeros_like
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.slinalg import solve
 from pytensor.tensor.variable import TensorVariable
 
@@ -33,7 +36,7 @@ def build_fn(self):
         self._fn = fn = function(self.inner_inputs, outputs)
 
         # Do this reassignment to see the compiled graph in the dprint
-        self.fgraph = fn.maker.fgraph
+        # self.fgraph = fn.maker.fgraph
 
         if self.inner_inputs[0].type.shape == ():
 
@@ -128,11 +131,11 @@ def perform(self, node, inputs, outputs):
             x0=x0,
             args=tuple(args),
             method=self.method,
-            **self.options,
+            **self.optimizer_kwargs,
         )
 
         outputs[0][0] = res.x
-        outputs[1][0] = res.success
+        outputs[1][0] = np.bool_(res.success)
 
     def L_op(self, inputs, outputs, output_grads):
         x, *args = inputs
@@ -158,26 +161,22 @@ def L_op(self, inputs, outputs, output_grads):
 
         df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], replace=replace)
 
-        grad_wrt_args_vector = solve(-df_dtheta_star, df_dx_star)
+        grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
 
         cursor = 0
         grad_wrt_args = []
 
-        for output_grad, arg in zip(output_grads, args, strict=True):
+        for arg in args:
             arg_shape = arg.shape
             arg_size = arg_shape.prod()
-            arg_grad = grad_wrt_args_vector[cursor : cursor + arg_size].reshape(
-                arg_shape
+            arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
+                (*x_star.shape, *arg_shape)
             )
 
-            grad_wrt_args.append(
-                arg_grad * output_grad
-                if not isinstance(output_grad.type, DisconnectedType)
-                else DisconnectedType()
-            )
+            grad_wrt_args.append(dot(output_grad, arg_grad))
             cursor += arg_size
 
-        return [x.zeros_like(), *grad_wrt_args]
+        return [zeros_like(x), *grad_wrt_args]
 
 
 def minimize(
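
Note on the L_op hunk above: the reordered solve is the implicit function theorem. The inner first-order condition is df/dx(x*, theta) = 0, so dx*/dtheta = -(df/dx)^(-1) @ df/dtheta, which is what solve(-df_dx_star, df_dtheta_star) computes; dot(output_grad, arg_grad) then forms the vector-Jacobian product for each arg. A minimal standalone NumPy/SciPy sketch (illustrative only, not part of this commit) that checks the formula on a quadratic objective with a known closed-form minimizer:

# Hedged sketch: verify dx*/dtheta = solve(-df_dx, df_dtheta) for
# f(x, theta) = 0.5 * x @ A @ x - theta @ x, whose minimizer is x* = inv(A) @ theta.
import numpy as np
from scipy.optimize import minimize as scipy_minimize

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))
A = A @ A.T + 3 * np.eye(3)          # symmetric positive-definite Hessian
theta = rng.normal(size=3)

x_star = scipy_minimize(lambda x: 0.5 * x @ A @ x - theta @ x, x0=np.zeros(3)).x

df_dx = A                            # Jacobian of the inner gradient A @ x - theta w.r.t. x
df_dtheta = -np.eye(3)               # Jacobian of the inner gradient w.r.t. theta
dxstar_dtheta = np.linalg.solve(-df_dx, df_dtheta)   # implicit function theorem

np.testing.assert_allclose(dxstar_dtheta, np.linalg.inv(A), rtol=1e-7, atol=1e-10)
np.testing.assert_allclose(x_star, np.linalg.solve(A, theta), atol=1e-4)
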
@@ -217,7 +216,7 @@ def minimize(
217216 """
218217 args = [
219218 arg
220- for arg in truncated_graph_inputs ([objective ], [x ])
219+ for arg in graph_inputs ([objective ], [x ])
221220 if (arg is not x and not isinstance (arg , Constant ))
222221 ]
223222
@@ -230,7 +229,18 @@ def minimize(
         optimizer_kwargs=optimizer_kwargs,
     )
 
-    return minimize_op(x, *args)
+    input_core_ndim = [var.ndim for var in minimize_op.inner_inputs]
+    input_signatures = [
+        f'({",".join(f"i{i}{n}" for n in range(ndim))})'
+        for i, ndim in enumerate(input_core_ndim)
+    ]
+
+    # Output dimensions are always the same as the first input (the initial values for the optimizer),
+    # then a scalar for the success flag
+    output_signatures = [input_signatures[0], "()"]
+
+    signature = f"{','.join(input_signatures)}->{','.join(output_signatures)}"
+    return Blockwise(minimize_op, signature=signature)(x, *args)
 
 
 class RootOp(ScipyWrapperOp):
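
For the final hunk, a standalone sketch (example ndims assumed, not taken from this commit) of the gufunc signature string that the Blockwise wrapping produces: the first output reuses the core dims of x, and the success flag is a scalar.

# Hedged sketch: mimic the signature construction for a vector x (ndim 1) and
# hypothetical extra args of ndim 0 and 2.
input_core_ndim = [1, 0, 2]          # stand-in for [var.ndim for var in inner_inputs]
input_signatures = [
    f'({",".join(f"i{i}{n}" for n in range(ndim))})'
    for i, ndim in enumerate(input_core_ndim)
]
output_signatures = [input_signatures[0], "()"]
signature = f"{','.join(input_signatures)}->{','.join(output_signatures)}"
print(signature)                     # (i00),(),(i20,i21)->(i00),()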