 import pytensor.scalar as ps
 from pytensor.compile.function import function
 from pytensor.gradient import grad, grad_not_implemented, jacobian
+from pytensor.graph import rewrite_graph
 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.graph.replace import graph_replace
-from pytensor.graph.traversal import ancestors, truncated_graph_inputs
+from pytensor.graph.traversal import (
+    ancestors,
+    explicit_graph_inputs,
+    truncated_graph_inputs,
+)
 from pytensor.scalar import ScalarType, ScalarVariable
+from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.basic import (
+    atleast_1d,
     atleast_2d,
-    concatenate,
     scalar_from_tensor,
     tensor,
     tensor_from_scalar,
     zeros_like,
 )
-from pytensor.tensor.math import dot
+from pytensor.tensor.math import tensordot
+from pytensor.tensor.reshape import pack, unpack
 from pytensor.tensor.slinalg import solve
 from pytensor.tensor.type import DenseTensorType
 from pytensor.tensor.variable import TensorVariable, Variable


+def get_ins(x):
+    return list(explicit_graph_inputs(x))
+
+
 # scipy.optimize can be slow to import, and will not be used by most users
 # We import scipy.optimize lazily inside optimization perform methods to avoid this.
 optimize = None
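
# Illustrative sketch of the lazy-import pattern described in the comment above;
# the helper name is hypothetical and the snippet is not part of the diff.
def _import_scipy_optimize():
    global optimize
    if optimize is None:
        import scipy.optimize as optimize  # binds the module-level name declared above
    return optimize
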
@@ -143,36 +154,6 @@ def _find_optimization_parameters(
     ]


-def _get_parameter_grads_from_vector(
-    grad_wrt_args_vector: TensorVariable,
-    x_star: TensorVariable,
-    args: Sequence[TensorVariable | ScalarVariable],
-    output_grad: TensorVariable,
-) -> list[TensorVariable | ScalarVariable]:
-    """
-    Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
-    returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
-    """
-    cursor = 0
-    grad_wrt_args = []
-
-    for arg in args:
-        arg_shape = arg.shape
-        arg_size = arg_shape.prod()
-        arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
-            (*x_star.shape, *arg_shape)
-        )
-
-        grad_wrt_arg = dot(output_grad, arg_grad)
-        if isinstance(arg.type, ScalarType):
-            grad_wrt_arg = scalar_from_tensor(grad_wrt_arg)
-        grad_wrt_args.append(grad_wrt_arg)
-
-        cursor += arg_size
-
-    return grad_wrt_args
-
-
 class ScipyWrapperOp(Op, HasInnerGraph):
     """Shared logic for scipy optimization ops"""

@@ -295,12 +276,14 @@ def scalar_implict_optimization_grads(


 def implict_optimization_grads(
-    df_dx: TensorVariable,
-    df_dtheta_columns: Sequence[TensorVariable],
+    implicit_f: TensorVariable,
+    inner_x: TensorVariable,
     args: Sequence[TensorVariable | ScalarVariable],
+    inner_args: Sequence[TensorVariable | ScalarVariable],
     x_star: TensorVariable,
     output_grad: TensorVariable,
     fgraph: FunctionGraph,
+    use_vectorized_jac: bool,
 ) -> list[TensorVariable | ScalarVariable]:
     r"""
     Compute gradients of an optimization problem with respect to its parameters.
@@ -321,19 +304,15 @@ def implict_optimization_grads(

     .. math::

-        \frac{d x^*(\theta)}{d \theta} = - \left(\frac{\partial f}{\partial x}\left(x^*(\theta), \theta\right)\right)^{-1} \frac{\partial f}{\partial \theta}\left(x^*(\theta), \theta\right)
+        \frac{d x^*(\theta)}{d \theta} = - \left(\frac{\partial f}{\partial x}\left(x^*(\theta),
+        \theta\right)\right)^{-1} \frac{\partial f}{\partial \theta}\left(x^*(\theta), \theta\right)

     Note that this method assumes `f(x_star(theta), theta) = 0`; so it is not immediately applicable to a minimization
     problem, where `f` is the objective function. In this case, we instead take `f` to be the gradient of the objective
     function, which *is* indeed zero at the minimum.

     Parameters
     ----------
-    df_dx : Variable
-        The Jacobian of the objective function with respect to the variable `x`.
-    df_dtheta_columns : Sequence[Variable]
-        The Jacobians of the objective function with respect to the optimization parameters `theta`.
-        Each column (or columns) corresponds to a different parameter. Should be returned by pytensor.gradient.jacobian.
     args : Sequence[Variable]
         The optimization parameters `theta`.
     x_star : Variable
@@ -343,23 +322,52 @@ def implict_optimization_grads(
     fgraph : FunctionGraph
         The function graph that contains the inputs and outputs of the optimization problem.
     """
-    df_dtheta = concatenate(
-        [atleast_2d(jac_col, left=False) for jac_col in df_dtheta_columns],
-        axis=-1,
+    packed_inner_args, packed_arg_shapes, implicit_f = (
+        _maybe_pack_input_variables_and_rewrite_objective(
+            implicit_f,
+            inner_args,
+        )
     )

-    replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
-
-    df_dx_star, df_dtheta_star = graph_replace(
-        [atleast_2d(df_dx), df_dtheta], replace=replace
+    df_dx, df_dtheta = jacobian(
+        implicit_f,
+        [inner_x, packed_inner_args],
+        disconnected_inputs="ignore",
+        vectorize=use_vectorized_jac,
     )

-    grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
-    grad_wrt_args = _get_parameter_grads_from_vector(
-        grad_wrt_args_vector, x_star, args, output_grad
-    )
+    inner_to_outer_map = dict(zip(fgraph.inputs, (x_star, *args)))

-    return grad_wrt_args
+    df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], inner_to_outer_map)
+    grad_wrt_args_packed = solve(-atleast_2d(df_dx_star), atleast_1d(df_dtheta_star))
+
+    if packed_arg_shapes is not None:
+        packed_shapes_from_outer = graph_replace(
+            packed_arg_shapes, inner_to_outer_map, strict=False
+        )
+        grad_wrt_args = unpack(
+            grad_wrt_args_packed,
+            packed_shapes=packed_shapes_from_outer,
+            axes=0 if not all(inp.ndim == 0 for inp in (x_star, *args)) else None,
+        )
+    else:
+        # There might have been a dimension added when performing the solve. In that case, squeeze it out.
+        if grad_wrt_args_packed.ndim > df_dtheta_star.ndim:
+            grad_wrt_args_packed = grad_wrt_args_packed.squeeze(axis=0)
+        grad_wrt_args = [grad_wrt_args_packed]
+
+    final_grads = [
+        tensordot(output_grad, arg_grad, [[0], [0]])
+        if arg_grad.ndim > 0 and output_grad.ndim > 0
+        else arg_grad * output_grad
+        for arg_grad in grad_wrt_args
+    ]
+    final_grads = [
+        scalar_from_tensor(g) if isinstance(arg.type, ScalarType) else g
+        for arg, g in zip(args, final_grads)
+    ]
+
+    return final_grads


 class MinimizeScalarOp(ScipyScalarWrapperOp):
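
# Illustrative numeric check of the implicit-function-theorem formula in the docstring
# above (a standalone sketch, not part of the diff): for the quadratic objective
# L(x, theta) = 0.5 * x @ A @ x - theta @ x, the residual f = dL/dx = A @ x - theta
# vanishes at x*(theta) = A^{-1} theta, so dx*/dtheta = -(df/dx)^{-1} df/dtheta = A^{-1}.
import numpy as np

A = np.array([[2.0, 0.5], [0.5, 1.0]])
theta = np.array([1.0, -3.0])

x_star = np.linalg.solve(A, theta)       # minimizer of L
df_dx = A                                # Jacobian of f with respect to x at (x*, theta)
df_dtheta = -np.eye(2)                   # Jacobian of f with respect to theta
dxstar_dtheta = np.linalg.solve(-df_dx, df_dtheta)

np.testing.assert_allclose(dxstar_dtheta, np.linalg.inv(A))
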
@@ -580,6 +588,7 @@ def perform(self, node, inputs, outputs):
     def L_op(self, inputs, outputs, output_grads):
         # TODO: Handle disconnected inputs
         x, *args = inputs
+
         if non_supported_types := tuple(
             inp.type
             for inp in inputs
@@ -591,50 +600,76 @@ def L_op(self, inputs, outputs, output_grads):
             return [
                 grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
             ]
+
         x_star, _success = outputs
         output_grad, _ = output_grads

         inner_x, *inner_args = self.fgraph.inputs
         inner_fx = self.fgraph.outputs[0]
-
         implicit_f = grad(inner_fx, inner_x)

-        df_dx, *df_dtheta_columns = jacobian(
-            implicit_f,
-            [inner_x, *inner_args],
-            disconnected_inputs="ignore",
-            vectorize=self.use_vectorized_jac,
-        )
-        grad_wrt_args = implict_optimization_grads(
-            df_dx=df_dx,
-            df_dtheta_columns=df_dtheta_columns,
+        final_grads = implict_optimization_grads(
+            implicit_f=implicit_f,
+            inner_x=inner_x,
+            inner_args=inner_args,
             args=args,
             x_star=x_star,
             output_grad=output_grad,
             fgraph=self.fgraph,
+            use_vectorized_jac=self.use_vectorized_jac,
         )

-        return [zeros_like(x), *grad_wrt_args]
+        return [zeros_like(x), *final_grads]
+
+
+def _maybe_pack_input_variables_and_rewrite_objective(
+    objective: TensorVariable,
+    x: TensorVariable | Sequence[TensorVariable],
+) -> tuple[TensorVariable, list[TensorVariable] | None, TensorVariable]:
+    packed_shapes = None
+
+    if not isinstance(x, Sequence):
+        packed_input = x
+    elif len(x) == 1:
+        packed_input = x[0]
+    else:
+        packed_input, packed_shapes = pack(*x, axes=None)
+        unpacked_output = unpack(packed_input, axes=None, packed_shapes=packed_shapes)
+
+        objective = graph_replace(
+            objective,
+            {
+                xi: ui.astype(xi.type.dtype)
+                if not (isinstance(xi.type, ScalarType))
+                else scalar_from_tensor(ui.astype(xi.type.dtype))
+                for xi, ui in zip(x, unpacked_output)
+            },
+        )
+        objective = rewrite_graph(objective, include=("ShapeOpt", "canonicalize"))
+    return packed_input, packed_shapes, objective


 def minimize(
     objective: TensorVariable,
-    x: TensorVariable,
+    x: TensorVariable | Sequence[TensorVariable],
     method: str = "BFGS",
     jac: bool = True,
     hess: bool = False,
     use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
-) -> tuple[TensorVariable, TensorVariable]:
+) -> (
+    tuple[TensorVariable, TensorVariable]
+    | tuple[tuple[TensorVariable, ...], TensorVariable]
+):
     """
     Minimize a scalar objective function using scipy.optimize.minimize.

     Parameters
     ----------
     objective : TensorVariable
         The objective function to minimize. This should be a pytensor variable representing a scalar value.
-    x: TensorVariable
-        The variable with respect to which the objective function is minimized. It must be an input to the
+    x: TensorVariable or list of TensorVariable
+        The variable or variables with respect to which the objective function is minimized. Each must be an input to the
         computational graph of `objective`.
     method: str, optional
         The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
@@ -653,18 +688,23 @@ def minimize(

     Returns
     -------
-    solution: TensorVariable
-        The optimized value of the vector of inputs `x` that minimizes `objective(x, *args)`. If the success flag
+    solution: TensorVariable or tuple of TensorVariable
+        The optimized value of each input in `x` that minimizes `objective(x, *args)`. If the success flag
         is False, this will be the final state of the minimization routine, but not necessarily a minimum.

     success: TensorVariable
         Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
         value, based on the requested convergence criteria.
     """
-    args = _find_optimization_parameters(objective, x)
+    objective = as_tensor_variable(objective)
+
+    packed_input, packed_shapes, objective = (
+        _maybe_pack_input_variables_and_rewrite_objective(objective, x)
+    )
+    args = _find_optimization_parameters(objective, packed_input)

     minimize_op = MinimizeOp(
-        x,
+        packed_input,
         *args,
         objective=objective,
         method=method,
@@ -674,7 +714,11 @@ def minimize(
         optimizer_kwargs=optimizer_kwargs,
     )

-    solution, success = minimize_op(x, *args)
+    solution, success = minimize_op(packed_input, *args)
+
+    if packed_shapes is not None:
+        solution = unpack(solution, axes=None, packed_shapes=packed_shapes)
+        # solution = rewrite_graph(solution, include=("ShapeOpt", "canonicalize", "stabilize"))

     return solution, success

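
# Usage sketch for the list-of-variables interface added to `minimize` above. This is an
# illustrative example rather than code from the diff; it assumes `minimize` is importable
# from pytensor.tensor.optimize and returns one optimum per variable when given a sequence.
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

x = pt.vector("x", shape=(3,))   # symbolic initial value for the vector block
y = pt.scalar("y")               # symbolic initial value for the scalar block
a = pt.scalar("a")               # extra parameter, picked up automatically as an arg

objective = ((x - a) ** 2).sum() + (y - 2 * a) ** 2
(x_star, y_star), success = minimize(objective, [x, y], method="BFGS")

fn = pytensor.function([x, y, a], [x_star, y_star, success])
# Expected optimum: x* = [a, a, a] and y* = 2 * a.
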
@@ -961,39 +1005,30 @@ def L_op(self, inputs, outputs, output_grads):
             return [
                 grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
             ]
+
         x_star, _ = outputs
         output_grad, _ = output_grads

         inner_x, *inner_args = self.fgraph.inputs
         inner_fx = self.fgraph.outputs[0]

-        df_dx = (
-            jacobian(inner_fx, inner_x, vectorize=self.use_vectorized_jac)
-            if not self.jac
-            else self.fgraph.outputs[1]
-        )
-        df_dtheta_columns = jacobian(
-            inner_fx,
-            inner_args,
-            disconnected_inputs="ignore",
-            vectorize=self.use_vectorized_jac,
-        )
-
-        grad_wrt_args = implict_optimization_grads(
-            df_dx=df_dx,
-            df_dtheta_columns=df_dtheta_columns,
+        final_grads = implict_optimization_grads(
+            implicit_f=inner_fx,
+            inner_x=inner_x,
+            inner_args=inner_args,
             args=args,
             x_star=x_star,
             output_grad=output_grad,
             fgraph=self.fgraph,
+            use_vectorized_jac=self.use_vectorized_jac,
         )

-        return [zeros_like(x), *grad_wrt_args]
+        return [zeros_like(x), *final_grads]


 def root(
     equations: TensorVariable,
-    variables: TensorVariable,
+    variables: TensorVariable | Sequence[TensorVariable],
     method: str = "hybr",
     jac: bool = True,
     use_vectorized_jac: bool = False,
@@ -1032,11 +1067,13 @@ def root(
     success: TensorVariable
         Boolean indicating whether the root-finding was successful. If True, the solution is a root of the equation
     """
-
-    args = _find_optimization_parameters(equations, variables)
+    packed_variables, packed_shapes, equations = (
+        _maybe_pack_input_variables_and_rewrite_objective(equations, variables)
+    )
+    args = _find_optimization_parameters(equations, packed_variables)

     root_op = RootOp(
-        variables,
+        packed_variables,
         *args,
         equations=equations,
         method=method,
@@ -1045,7 +1082,10 @@ def root(
         use_vectorized_jac=use_vectorized_jac,
     )

-    solution, success = root_op(variables, *args)
+    solution, success = root_op(packed_variables, *args)
+    if packed_shapes is not None:
+        solution = unpack(solution, axes=None, packed_shapes=packed_shapes)
+        # rewrite_graph(solution, include=("ShapeOpt", "canonicalize", "stabilize"))

     return solution, success

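
# Usage sketch for passing a sequence of variables to `root` above. Illustrative only, not
# code from the diff; the import path and the per-variable return structure are assumed to
# mirror `minimize`.
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import root

a = pt.scalar("a")
x = pt.scalar("x")   # symbolic initial guess for the first unknown
y = pt.scalar("y")   # symbolic initial guess for the second unknown

# System of equations: x + y - a = 0 and x - y = 0, whose root is x* = y* = a / 2.
equations = pt.stack([x + y - a, x - y])
(x_star, y_star), success = root(equations, [x, y], method="hybr")

fn = pytensor.function([x, y, a], [x_star, y_star, success])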