 from scipy.optimize import root as scipy_root
 from scipy.optimize import root_scalar as scipy_root_scalar

+import pytensor.scalar as ps
 from pytensor import Variable, function, graph_replace
 from pytensor.gradient import grad, hessian, jacobian
 from pytensor.graph import Apply, Constant, FunctionGraph
-from pytensor.graph.basic import truncated_graph_inputs
+from pytensor.graph.basic import ancestors, truncated_graph_inputs
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
-from pytensor.scalar import bool as scalar_bool
-from pytensor.tensor import dot
-from pytensor.tensor.basic import atleast_2d, concatenate, zeros_like
+from pytensor.tensor.basic import (
+    atleast_2d,
+    concatenate,
+    tensor,
+    tensor_from_scalar,
+    zeros_like,
+)
+from pytensor.tensor.math import dot
 from pytensor.tensor.slinalg import solve
 from pytensor.tensor.variable import TensorVariable

@@ -223,9 +229,31 @@ def make_node(self, *inputs):
                 input.type == inner_input.type
             ), f"Input {input} does not match expected type {inner_input.type}"

-        return Apply(
-            self, inputs, [self.inner_inputs[0].type(), scalar_bool("success")]
-        )
+        return Apply(self, inputs, [self.inner_inputs[0].type(), ps.bool("success")])
+
+
+class ScipyScalarWrapperOp(ScipyWrapperOp):
+    def build_fn(self):
+        """
+        This is overloaded because scipy converts scalar inputs to lists, changing the return type. The
+        wrapper function logic is there to handle this.
+        """
+
+        # We have no control over the inputs to the scipy inner function for scalar_minimize. As a result,
+        # we need to adjust the graph to work with what scipy will be passing into the inner function --
+        # always scalar, and always float64
+        x, *args = self.inner_inputs
+        new_root_x = ps.float64(name="x_scalar")
+        new_x = tensor_from_scalar(new_root_x.astype(x.type.dtype))
+
+        new_outputs = graph_replace(self.inner_outputs, {x: new_x})
+
+        self._fn = fn = function([new_root_x, *args], new_outputs, trust_input=True)
+
+        # Do this reassignment to see the compiled graph in the dprint
+        # self.fgraph = fn.maker.fgraph
+
+        self._fn_wrapped = LRUCache1(fn)


 def scalar_implict_optimization_grads(
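A minimal sketch of the scipy behaviour that the new ScipyScalarWrapperOp works around (editor's illustration, not part of the commit; the objective and values are made up): scipy.optimize.minimize_scalar always calls the objective with a plain Python float, regardless of the dtype of the symbolic graph, which is why build_fn rebinds the inner graph to a float64 scalar input and casts back to the graph's dtype.

    import numpy as np
    from scipy.optimize import minimize_scalar

    def objective(x):
        # scipy hands x over as a bare Python/float64 scalar,
        # even if the compiled graph expects e.g. float32
        assert isinstance(x, float)
        return (x - 2.0) ** 2

    res = minimize_scalar(objective, method="brent")
    print(res.x, res.success)  # approximately 2.0, True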
@@ -327,7 +355,7 @@ def implict_optimization_grads(
     return grad_wrt_args


-class MinimizeScalarOp(ScipyWrapperOp):
+class MinimizeScalarOp(ScipyScalarWrapperOp):
     __props__ = ("method",)

     def __init__(
@@ -338,6 +366,14 @@ def __init__(
         method: str = "brent",
         optimizer_kwargs: dict | None = None,
     ):
+        if not x.ndim == 0:
+            raise ValueError(
+                "The variable `x` must be a scalar (0-dimensional) tensor for minimize_scalar."
+            )
+        if not objective.ndim == 0:
+            raise ValueError(
+                "The objective function must be a scalar (0-dimensional) tensor for minimize_scalar."
+            )
         self.fgraph = FunctionGraph([x, *args], [objective])

         self.method = method
@@ -351,7 +387,7 @@ def perform(self, node, inputs, outputs):

         # minimize_scalar doesn't take x0 as an argument. The Op still needs this input (to symbolically determine
         # the args of the objective function), but it is not used in the optimization.
-        _, *args = inputs
+        x0, *args = inputs

         res = scipy_minimize_scalar(
             fun=f.value,
@@ -360,7 +396,7 @@ def perform(self, node, inputs, outputs):
             **self.optimizer_kwargs,
         )

-        outputs[0][0] = np.array(res.x)
+        outputs[0][0] = np.array(res.x, dtype=x0.dtype)
         outputs[1][0] = np.bool_(res.success)

     def L_op(self, inputs, outputs, output_grads):
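A short illustration of why the result is now routed through x0.dtype (editor's sketch, not part of the commit; the numbers are arbitrary): scipy reports res.x as a float64-like scalar, so re-wrapping it with the dtype of the symbolic x0 keeps the Op's output type consistent with its input, e.g. in a float32 graph.

    import numpy as np

    x0 = np.float32(1.0)   # dtype of the symbolic input
    res_x = 1.9999999997   # scipy returns a float64-like scalar
    out = np.array(res_x, dtype=x0.dtype)
    assert out.dtype == np.float32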
@@ -423,6 +459,15 @@ def __init__(
         hessp: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
+        if not objective.ndim == 0:
+            raise ValueError(
+                "The objective function must be a scalar (0-dimensional) tensor for minimize."
+            )
+        if not isinstance(x, Variable) or x not in ancestors([objective]):
+            raise ValueError(
+                "The variable `x` must be an input to the computational graph of the objective function."
+            )
+
         self.fgraph = FunctionGraph([x, *args], [objective])

         if jac:
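A small sketch of the graph-membership test that the new validation relies on (illustrative only, not part of the commit; variable names are made up): ancestors lets the Op confirm that x actually feeds into the objective graph before the FunctionGraph is built.

    import pytensor.tensor as pt
    from pytensor.graph.basic import ancestors

    x = pt.scalar("x")
    y = pt.scalar("y")
    objective = y**2  # x does not appear anywhere in this graph

    assert x not in ancestors([objective])  # this condition triggers the ValueError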
@@ -462,7 +507,7 @@ def perform(self, node, inputs, outputs):

         f.clear_cache()

-        outputs[0][0] = res.x
+        outputs[0][0] = res.x.astype(x0.dtype)
         outputs[1][0] = np.bool_(res.success)

     def L_op(self, inputs, outputs, output_grads):
@@ -541,7 +586,7 @@ def minimize(
     return minimize_op(x, *args)


-class RootScalarOp(ScipyWrapperOp):
+class RootScalarOp(ScipyScalarWrapperOp):
     __props__ = ("method", "jac", "hess")

     def __init__(
@@ -554,6 +599,17 @@ def __init__(
         hess: bool = False,
         optimizer_kwargs=None,
     ):
+        if not equation.ndim == 0:
+            raise ValueError(
+                "The equation must be a scalar (0-dimensional) tensor for root_scalar."
+            )
+        if not isinstance(variables, Variable) or variables not in ancestors(
+            [equation]
+        ):
+            raise ValueError(
+                "The variable `variables` must be an input to the computational graph of the equation."
+            )
+
         self.fgraph = FunctionGraph([variables, *args], [equation])

         if jac:
@@ -673,6 +729,32 @@ def __init__(
         self._fn = None
         self._fn_wrapped = None

+    def build_fn(self):
+        outputs = self.inner_outputs
+        variables, *args = self.inner_inputs
+
+        if variables.ndim > 0:
+            new_root_variables = variables
+            new_outputs = outputs
+        else:
+            # If the user passes a scalar optimization problem to root, scipy will automatically upcast it to
+            # a 1d array. The inner function needs to be adjusted to handle this.
+            new_root_variables = tensor(
+                name="variables_vector", shape=(1,), dtype=variables.type.dtype
+            )
+            new_variables = new_root_variables.squeeze()
+
+            new_outputs = graph_replace(outputs, {variables: new_variables})
+
+        self._fn = fn = function(
+            [new_root_variables, *args], new_outputs, trust_input=True
+        )
+
+        # Do this reassignment to see the compiled graph in the dprint
+        # self.fgraph = fn.maker.fgraph
+
+        self._fn_wrapped = LRUCache1(fn)
+
     def perform(self, node, inputs, outputs):
         f = self.fn_wrapped
         f.clear_cache()
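A minimal sketch of the scipy behaviour that the new RootOp.build_fn accounts for (editor's illustration, not part of the commit; fun and x0 are made up): scipy.optimize.root flattens the initial guess into a 1-d array before calling fun, so a scalar problem reaches the inner function with shape (1,), which is why the compiled graph is rebuilt around a length-1 vector input and squeezed back to a scalar.

    import numpy as np
    from scipy.optimize import root

    def fun(v):
        assert v.shape == (1,)  # a scalar x0 arrives as a length-1 vector
        return v**2 - 2.0

    res = root(fun, x0=1.0, method="hybr")
    print(res.x, res.success)  # approximately [1.41421356], True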