Commit 33889f0

Add specialized build_fn to each Op to handle scipy quirks
1 parent f06526e commit 33889f0

1 file changed

pytensor/tensor/optimize.py

Lines changed: 94 additions & 12 deletions
@@ -9,14 +9,20 @@
 from scipy.optimize import root as scipy_root
 from scipy.optimize import root_scalar as scipy_root_scalar
 
+import pytensor.scalar as ps
 from pytensor import Variable, function, graph_replace
 from pytensor.gradient import grad, hessian, jacobian
 from pytensor.graph import Apply, Constant, FunctionGraph
-from pytensor.graph.basic import truncated_graph_inputs
+from pytensor.graph.basic import ancestors, truncated_graph_inputs
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
-from pytensor.scalar import bool as scalar_bool
-from pytensor.tensor import dot
-from pytensor.tensor.basic import atleast_2d, concatenate, zeros_like
+from pytensor.tensor.basic import (
+    atleast_2d,
+    concatenate,
+    tensor,
+    tensor_from_scalar,
+    zeros_like,
+)
+from pytensor.tensor.math import dot
 from pytensor.tensor.slinalg import solve
 from pytensor.tensor.variable import TensorVariable
 
@@ -223,9 +229,31 @@ def make_node(self, *inputs):
                 input.type == inner_input.type
             ), f"Input {input} does not match expected type {inner_input.type}"
 
-        return Apply(
-            self, inputs, [self.inner_inputs[0].type(), scalar_bool("success")]
-        )
+        return Apply(self, inputs, [self.inner_inputs[0].type(), ps.bool("success")])
+
+
+class ScipyScalarWrapperOp(ScipyWrapperOp):
+    def build_fn(self):
+        """
+        This is overloaded because scipy converts scalar inputs to lists, changing the return type. The
+        wrapper function logic is there to handle this.
+        """
+
+        # We have no control over the inputs to the scipy inner function for scalar_minimize. As a result,
+        # we need to adjust the graph to work with what scipy will be passing into the inner function --
+        # always scalar, and always float64
+        x, *args = self.inner_inputs
+        new_root_x = ps.float64(name="x_scalar")
+        new_x = tensor_from_scalar(new_root_x.astype(x.type.dtype))
+
+        new_outputs = graph_replace(self.inner_outputs, {x: new_x})
+
+        self._fn = fn = function([new_root_x, *args], new_outputs, trust_input=True)
+
+        # Do this reassignment to see the compiled graph in the dprint
+        # self.fgraph = fn.maker.fgraph
+
+        self._fn_wrapped = LRUCache1(fn)
 
 
 def scalar_implict_optimization_grads(
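
Not part of the diff: a minimal standalone sketch of the scipy behaviour the comment in build_fn refers to. scipy.optimize.minimize_scalar evaluates the objective at plain scalar floats (float64 precision) rather than arrays of the caller's dtype, which is why the compiled inner function is rebuilt around a float64 scalar input.

import numpy as np
from scipy.optimize import minimize_scalar

seen_types = []

def objective(x):
    # scipy hands the objective a scalar float, not an ndarray of the
    # dtype the surrounding graph was built with
    seen_types.append(type(x))
    return (x - 3.0) ** 2

res = minimize_scalar(objective, method="brent")
print(seen_types[0])  # a scalar float, not an ndarray
print(res.x)          # approximately 3.0
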
@@ -327,7 +355,7 @@ def implict_optimization_grads(
     return grad_wrt_args
 
 
-class MinimizeScalarOp(ScipyWrapperOp):
+class MinimizeScalarOp(ScipyScalarWrapperOp):
     __props__ = ("method",)
 
     def __init__(
@@ -338,6 +366,14 @@ def __init__(
         method: str = "brent",
         optimizer_kwargs: dict | None = None,
     ):
+        if not x.ndim == 0:
+            raise ValueError(
+                "The variable `x` must be a scalar (0-dimensional) tensor for minimize_scalar."
+            )
+        if not objective.ndim == 0:
+            raise ValueError(
+                "The objective function must be a scalar (0-dimensional) tensor for minimize_scalar."
+            )
         self.fgraph = FunctionGraph([x, *args], [objective])
 
         self.method = method
@@ -351,7 +387,7 @@ def perform(self, node, inputs, outputs):
 
         # minimize_scalar doesn't take x0 as an argument. The Op still needs this input (to symbolically determine
         # the args of the objective function), but it is not used in the optimization.
-        _, *args = inputs
+        x0, *args = inputs
 
         res = scipy_minimize_scalar(
             fun=f.value,
@@ -360,7 +396,7 @@ def perform(self, node, inputs, outputs):
             **self.optimizer_kwargs,
         )
 
-        outputs[0][0] = np.array(res.x)
+        outputs[0][0] = np.array(res.x, dtype=x0.dtype)
         outputs[1][0] = np.bool_(res.success)
 
     def L_op(self, inputs, outputs, output_grads):
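
A quick illustration (not from the commit) of why x0.dtype is threaded through here: np.array on the scalar result scipy returns defaults to float64, which would not match the Op's declared output type when the graph uses a narrower dtype such as float32.

import numpy as np

x0 = np.float32(0.5)   # stand-in for the Op's input, whose dtype the output must match
res_x = 2.9999999999   # scipy's result is a Python float (float64 precision)

print(np.array(res_x).dtype)                   # float64 -- mismatches a float32 graph
print(np.array(res_x, dtype=x0.dtype).dtype)   # float32 -- matches, as in the change above
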
@@ -423,6 +459,15 @@ def __init__(
         hessp: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
+        if not objective.ndim == 0:
+            raise ValueError(
+                "The objective function must be a scalar (0-dimensional) tensor for minimize."
+            )
+        if not isinstance(x, Variable) and x not in ancestors([objective]):
+            raise ValueError(
+                "The variable `x` must be an input to the computational graph of the objective function."
+            )
+
         self.fgraph = FunctionGraph([x, *args], [objective])
 
         if jac:
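
For illustration only (not part of the commit), the new validation leans on pytensor.graph.basic.ancestors to confirm that x actually feeds into the objective graph. A rough sketch of that membership check, using hypothetical variable names:

import pytensor.tensor as pt
from pytensor.graph.basic import ancestors

x = pt.scalar("x")
a = pt.scalar("a")
unrelated = pt.scalar("z")

objective = (x - a) ** 2

# x participates in the objective's graph, the unrelated variable does not
print(x in ancestors([objective]))          # True
print(unrelated in ancestors([objective]))  # False
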
@@ -462,7 +507,7 @@ def perform(self, node, inputs, outputs):
 
         f.clear_cache()
 
-        outputs[0][0] = res.x
+        outputs[0][0] = res.x.astype(x0.dtype)
         outputs[1][0] = np.bool_(res.success)
 
     def L_op(self, inputs, outputs, output_grads):
@@ -541,7 +586,7 @@ def minimize(
     return minimize_op(x, *args)
 
 
-class RootScalarOp(ScipyWrapperOp):
+class RootScalarOp(ScipyScalarWrapperOp):
     __props__ = ("method", "jac", "hess")
 
     def __init__(
@@ -554,6 +599,17 @@ def __init__(
         hess: bool = False,
         optimizer_kwargs=None,
     ):
+        if not equation.ndim == 0:
+            raise ValueError(
+                "The equation must be a scalar (0-dimensional) tensor for root_scalar."
+            )
+        if not isinstance(variables, Variable) or variables not in ancestors(
+            [equation]
+        ):
+            raise ValueError(
+                "The variable `variables` must be an input to the computational graph of the equation."
+            )
+
         self.fgraph = FunctionGraph([variables, *args], [equation])
 
         if jac:
@@ -673,6 +729,32 @@ def __init__(
         self._fn = None
         self._fn_wrapped = None
 
+    def build_fn(self):
+        outputs = self.inner_outputs
+        variables, *args = self.inner_inputs
+
+        if variables.ndim > 0:
+            new_root_variables = variables
+            new_outputs = outputs
+        else:
+            # If the user passes a scalar optimization problem to root, scipy will automatically upcast it to
+            # a 1d array. The inner function needs to be adjusted to handle this.
+            new_root_variables = tensor(
+                name="variables_vector", shape=(1,), dtype=variables.type.dtype
+            )
+            new_variables = new_root_variables.squeeze()
+
+            new_outputs = graph_replace(outputs, {variables: new_variables})
+
+        self._fn = fn = function(
+            [new_root_variables, *args], new_outputs, trust_input=True
+        )
+
+        # Do this reassignment to see the compiled graph in the dprint
+        # self.fgraph = fn.maker.fgraph
+
+        self._fn_wrapped = LRUCache1(fn)
+
     def perform(self, node, inputs, outputs):
         f = self.fn_wrapped
         f.clear_cache()
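
Also not part of the diff: a standalone check of the upcast behaviour the comment in this build_fn describes. scipy.optimize.root passes the inner function a 1-d array even when the starting point is a scalar, so the inner graph must accept a length-1 vector and squeeze it back down.

import numpy as np
from scipy.optimize import root

seen_shapes = []

def equation(v):
    # scipy flattens the scalar initial guess into a length-1 array
    seen_shapes.append(np.shape(v))
    return v**2 - 4.0

res = root(equation, x0=1.0, method="hybr")
print(seen_shapes[0])  # (1,)
print(res.x)           # array([2.]) -- the solution is also returned as a 1-d array
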
