Commit 3302ea9

Use pack and unpack in minimize and root
1 parent 93c9190 commit 3302ea9

File tree

2 files changed: +250 −98 lines


pytensor/tensor/optimize.py

Lines changed: 134 additions & 94 deletions
@@ -7,26 +7,37 @@
 import pytensor.scalar as ps
 from pytensor.compile.function import function
 from pytensor.gradient import grad, grad_not_implemented, jacobian
+from pytensor.graph import rewrite_graph
 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
 from pytensor.graph.replace import graph_replace
-from pytensor.graph.traversal import ancestors, truncated_graph_inputs
+from pytensor.graph.traversal import (
+    ancestors,
+    explicit_graph_inputs,
+    truncated_graph_inputs,
+)
 from pytensor.scalar import ScalarType, ScalarVariable
+from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.basic import (
+    atleast_1d,
     atleast_2d,
-    concatenate,
     scalar_from_tensor,
     tensor,
     tensor_from_scalar,
     zeros_like,
 )
-from pytensor.tensor.math import dot
+from pytensor.tensor.math import tensordot
+from pytensor.tensor.reshape import pack, unpack
 from pytensor.tensor.slinalg import solve
 from pytensor.tensor.type import DenseTensorType
 from pytensor.tensor.variable import TensorVariable, Variable


+def get_ins(x):
+    return list(explicit_graph_inputs(x))
+
+
 # scipy.optimize can be slow to import, and will not be used by most users
 # We import scipy.optimize lazily inside optimization perform methods to avoid this.
 optimize = None
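The newly imported pack and unpack helpers (from pytensor.tensor.reshape) are what this commit builds on: several optimization variables are flattened and concatenated into one vector before being handed to scipy, and the flat result is later restored to the original shapes. The helpers themselves are not part of this diff, so the snippet below is only a plain-numpy sketch of the flatten-and-restore round trip they are used for; the function names and shapes here are invented for illustration.

import numpy as np

def pack_like(*arrays):
    # Flatten each array and concatenate, remembering the original shapes.
    shapes = [a.shape for a in arrays]
    return np.concatenate([a.ravel() for a in arrays]), shapes

def unpack_like(flat, shapes):
    # Split the flat vector at the recorded sizes and restore each shape.
    sizes = [int(np.prod(s)) for s in shapes]
    chunks = np.split(flat, np.cumsum(sizes)[:-1])
    return [c.reshape(s) for c, s in zip(chunks, shapes)]

a, b = np.zeros((2, 3)), np.ones(4)
flat, shapes = pack_like(a, b)          # flat vector of length 10
a2, b2 = unpack_like(flat, shapes)      # shapes (2, 3) and (4,) restored
assert a2.shape == (2, 3) and b2.shape == (4,)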
@@ -143,36 +154,6 @@ def _find_optimization_parameters(
     ]


-def _get_parameter_grads_from_vector(
-    grad_wrt_args_vector: TensorVariable,
-    x_star: TensorVariable,
-    args: Sequence[TensorVariable | ScalarVariable],
-    output_grad: TensorVariable,
-) -> list[TensorVariable | ScalarVariable]:
-    """
-    Given a single concatenated vector of objective function gradients with respect to raveled optimization parameters,
-    returns the contribution of each parameter to the total loss function, with the unraveled shape of the parameter.
-    """
-    cursor = 0
-    grad_wrt_args = []
-
-    for arg in args:
-        arg_shape = arg.shape
-        arg_size = arg_shape.prod()
-        arg_grad = grad_wrt_args_vector[:, cursor : cursor + arg_size].reshape(
-            (*x_star.shape, *arg_shape)
-        )
-
-        grad_wrt_arg = dot(output_grad, arg_grad)
-        if isinstance(arg.type, ScalarType):
-            grad_wrt_arg = scalar_from_tensor(grad_wrt_arg)
-        grad_wrt_args.append(grad_wrt_arg)
-
-        cursor += arg_size
-
-    return grad_wrt_args
-
-
 class ScipyWrapperOp(Op, HasInnerGraph):
     """Shared logic for scipy optimization ops"""

@@ -295,12 +276,14 @@ def scalar_implict_optimization_grads(


 def implict_optimization_grads(
-    df_dx: TensorVariable,
-    df_dtheta_columns: Sequence[TensorVariable],
+    implicit_f: TensorVariable,
+    inner_x: TensorVariable,
     args: Sequence[TensorVariable | ScalarVariable],
+    inner_args: Sequence[TensorVariable | ScalarVariable],
     x_star: TensorVariable,
     output_grad: TensorVariable,
     fgraph: FunctionGraph,
+    use_vectorized_jac: bool,
 ) -> list[TensorVariable | ScalarVariable]:
     r"""
     Compute gradients of an optimization problem with respect to its parameters.
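An earlier hunk removes _get_parameter_grads_from_vector, which walked a cursor over one concatenated gradient matrix and reshaped each column block back to the corresponding parameter's shape; unpack with the recorded packed shapes now plays that role inside implict_optimization_grads. For orientation, a small numpy sketch of the slicing the removed helper performed (shapes are made up):

import numpy as np

x_size = 3                               # size of the flattened solution x_star
arg_shapes = [(2,), (2, 2)]              # two parameters, 2 + 4 = 6 entries total
grad_matrix = np.arange(x_size * 6, dtype=float).reshape(x_size, 6)

cursor, grads = 0, []
for shape in arg_shapes:
    size = int(np.prod(shape))
    # Take this parameter's columns and restore its shape, keeping the leading x dimension.
    grads.append(grad_matrix[:, cursor : cursor + size].reshape(x_size, *shape))
    cursor += size

print([g.shape for g in grads])          # [(3, 2), (3, 2, 2)]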
@@ -321,19 +304,15 @@ def implict_optimization_grads(

     .. math::

-        \frac{d x^*(\theta)}{d \theta} = - \left(\frac{\partial f}{\partial x}\left(x^*(\theta), \theta\right)\right)^{-1} \frac{\partial f}{\partial \theta}\left(x^*(\theta), \theta\right)
+        \frac{d x^*(\theta)}{d \theta} = - \left(\frac{\partial f}{\partial x}\left(x^*(\theta),
+        \theta\right)\right)^{-1} \frac{\partial f}{\partial \theta}\left(x^*(\theta), \theta\right)

     Note that this method assumes `f(x_star(theta), theta) = 0`; so it is not immediately applicable to a minimization
     problem, where `f` is the objective function. In this case, we instead take `f` to be the gradient of the objective
     function, which *is* indeed zero at the minimum.

     Parameters
     ----------
-    df_dx : Variable
-        The Jacobian of the objective function with respect to the variable `x`.
-    df_dtheta_columns : Sequence[Variable]
-        The Jacobians of the objective function with respect to the optimization parameters `theta`.
-        Each column (or columns) corresponds to a different parameter. Should be returned by pytensor.gradient.jacobian.
     args : Sequence[Variable]
         The optimization parameters `theta`.
     x_star : Variable
@@ -343,23 +322,52 @@ def implict_optimization_grads(
     fgraph : FunctionGraph
         The function graph that contains the inputs and outputs of the optimization problem.
     """
-    df_dtheta = concatenate(
-        [atleast_2d(jac_col, left=False) for jac_col in df_dtheta_columns],
-        axis=-1,
+    packed_inner_args, packed_arg_shapes, implicit_f = (
+        _maybe_pack_input_variables_and_rewrite_objective(
+            implicit_f,
+            inner_args,
+        )
     )

-    replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
-
-    df_dx_star, df_dtheta_star = graph_replace(
-        [atleast_2d(df_dx), df_dtheta], replace=replace
+    df_dx, df_dtheta = jacobian(
+        implicit_f,
+        [inner_x, packed_inner_args],
+        disconnected_inputs="ignore",
+        vectorize=use_vectorized_jac,
     )

-    grad_wrt_args_vector = solve(-df_dx_star, df_dtheta_star)
-    grad_wrt_args = _get_parameter_grads_from_vector(
-        grad_wrt_args_vector, x_star, args, output_grad
-    )
+    inner_to_outer_map = dict(zip(fgraph.inputs, (x_star, *args)))

-    return grad_wrt_args
+    df_dx_star, df_dtheta_star = graph_replace([df_dx, df_dtheta], inner_to_outer_map)
+    grad_wrt_args_packed = solve(-atleast_2d(df_dx_star), atleast_1d(df_dtheta_star))
+
+    if packed_arg_shapes is not None:
+        packed_shapes_from_outer = graph_replace(
+            packed_arg_shapes, inner_to_outer_map, strict=False
+        )
+        grad_wrt_args = unpack(
+            grad_wrt_args_packed,
+            packed_shapes=packed_shapes_from_outer,
+            axes=0 if not all(inp.ndim == 0 for inp in (x_star, *args)) else None,
+        )
+    else:
+        # There might have been a dimension added when performing the solve. In that case, squeeze it out.
+        if grad_wrt_args_packed.ndim > df_dtheta_star.ndim:
+            grad_wrt_args_packed = grad_wrt_args_packed.squeeze(axis=0)
+        grad_wrt_args = [grad_wrt_args_packed]
+
+    final_grads = [
+        tensordot(output_grad, arg_grad, [[0], [0]])
+        if arg_grad.ndim > 0 and output_grad.ndim > 0
+        else arg_grad * output_grad
+        for arg_grad in grad_wrt_args
+    ]
+    final_grads = [
+        scalar_from_tensor(g) if isinstance(arg.type, ScalarType) else g
+        for arg, g in zip(args, final_grads)
+    ]
+
+    return final_grads


 class MinimizeScalarOp(ScipyScalarWrapperOp):
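The rewritten implict_optimization_grads still computes gradients from the implicit-function-theorem identity quoted in the docstring: the two Jacobians at the solution are combined with a linear solve, then contracted against the incoming output gradient (the tensordot above). A minimal numerical check of that identity on a scalar toy problem, in plain Python with made-up numbers:

def solve_root(theta, lo=0.0, hi=2.0, iters=80):
    # Find x with f(x, theta) = x**3 + x - theta = 0 by bisection (f is increasing in x).
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        lo, hi = (lo, mid) if mid**3 + mid - theta > 0 else (mid, hi)
    return 0.5 * (lo + hi)

theta = 2.0
x_star = solve_root(theta)              # root is x = 1 for theta = 2

# Implicit-function-theorem gradient: dx*/dtheta = -(df/dx)^(-1) * df/dtheta.
df_dx = 3 * x_star**2 + 1               # partial of f w.r.t. x at the solution
df_dtheta = -1.0                        # partial of f w.r.t. theta
ift_grad = -df_dtheta / df_dx           # = 0.25

# Compare against a central finite difference of the solver itself.
eps = 1e-5
fd_grad = (solve_root(theta + eps) - solve_root(theta - eps)) / (2 * eps)
assert abs(ift_grad - fd_grad) < 1e-6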
@@ -580,6 +588,7 @@ def perform(self, node, inputs, outputs):
     def L_op(self, inputs, outputs, output_grads):
         # TODO: Handle disconnected inputs
         x, *args = inputs
+
         if non_supported_types := tuple(
             inp.type
             for inp in inputs
@@ -591,50 +600,76 @@ def L_op(self, inputs, outputs, output_grads):
             return [
                 grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
             ]
+
         x_star, _success = outputs
         output_grad, _ = output_grads

         inner_x, *inner_args = self.fgraph.inputs
         inner_fx = self.fgraph.outputs[0]
-
         implicit_f = grad(inner_fx, inner_x)

-        df_dx, *df_dtheta_columns = jacobian(
-            implicit_f,
-            [inner_x, *inner_args],
-            disconnected_inputs="ignore",
-            vectorize=self.use_vectorized_jac,
-        )
-        grad_wrt_args = implict_optimization_grads(
-            df_dx=df_dx,
-            df_dtheta_columns=df_dtheta_columns,
+        final_grads = implict_optimization_grads(
+            implicit_f=implicit_f,
+            inner_x=inner_x,
+            inner_args=inner_args,
             args=args,
             x_star=x_star,
             output_grad=output_grad,
             fgraph=self.fgraph,
+            use_vectorized_jac=self.use_vectorized_jac,
         )

-        return [zeros_like(x), *grad_wrt_args]
+        return [zeros_like(x), *final_grads]
+
+
+def _maybe_pack_input_variables_and_rewrite_objective(
+    objective: TensorVariable,
+    x: TensorVariable | Sequence[TensorVariable],
+) -> tuple[TensorVariable, list[TensorVariable] | None, TensorVariable]:
+    packed_shapes = None
+
+    if not isinstance(x, Sequence):
+        packed_input = x
+    elif len(x) == 1:
+        packed_input = x[0]
+    else:
+        packed_input, packed_shapes = pack(*x, axes=None)
+        unpacked_output = unpack(packed_input, axes=None, packed_shapes=packed_shapes)
+
+        objective = graph_replace(
+            objective,
+            {
+                xi: ui.astype(xi.type.dtype)
+                if not (isinstance(xi.type, ScalarType))
+                else scalar_from_tensor(ui.astype(xi.type.dtype))
+                for xi, ui in zip(x, unpacked_output)
+            },
+        )
+        objective = rewrite_graph(objective, include=("ShapeOpt", "canonicalize"))
+    return packed_input, packed_shapes, objective


 def minimize(
     objective: TensorVariable,
-    x: TensorVariable,
+    x: TensorVariable | Sequence[TensorVariable],
     method: str = "BFGS",
     jac: bool = True,
     hess: bool = False,
     use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
-) -> tuple[TensorVariable, TensorVariable]:
+) -> (
+    tuple[TensorVariable, TensorVariable]
+    | tuple[tuple[TensorVariable, ...], TensorVariable]
+):
     """
     Minimize a scalar objective function using scipy.optimize.minimize.

     Parameters
     ----------
     objective : TensorVariable
         The objective function to minimize. This should be a pytensor variable representing a scalar value.
-    x: TensorVariable
-        The variable with respect to which the objective function is minimized. It must be an input to the
+    x: TensorVariable or list of TensorVariable
+        The variable or variables with respect to which the objective function is minimized. It must be an input to the
         computational graph of `objective`.
     method: str, optional
         The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
@@ -653,18 +688,23 @@

     Returns
     -------
-    solution: TensorVariable
-        The optimized value of the vector of inputs `x` that minimizes `objective(x, *args)`. If the success flag
+    solution: TensorVariable or tuple of TensorVariable
+        The optimized value of each of inputs `x` that minimizes `objective(x, *args)`. If the success flag
         is False, this will be the final state of the minimization routine, but not necessarily a minimum.

     success: TensorVariable
         Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
         value, based on the requested convergence criteria.
     """
-    args = _find_optimization_parameters(objective, x)
+    objective = as_tensor_variable(objective)
+
+    packed_input, packed_shapes, objective = (
+        _maybe_pack_input_variables_and_rewrite_objective(objective, x)
+    )
+    args = _find_optimization_parameters(objective, packed_input)

     minimize_op = MinimizeOp(
-        x,
+        packed_input,
         *args,
         objective=objective,
         method=method,
@@ -674,7 +714,11 @@
         optimizer_kwargs=optimizer_kwargs,
     )

-    solution, success = minimize_op(x, *args)
+    solution, success = minimize_op(packed_input, *args)
+
+    if packed_shapes is not None:
+        solution = unpack(solution, axes=None, packed_shapes=packed_shapes)
+        # solution = rewrite_graph(solution, include=("ShapeOpt", "canonicalize", "stabilize"))

     return solution, success

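With the changes above, minimize accepts either a single variable or a list; a list is packed into one flat vector internally and the solution is unpacked back into one output per variable. A usage sketch based on the new signature (variable names and data are invented, and the exact call pattern is untested against this revision):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

w = pt.dvector("w")             # first optimization variable
b = pt.dscalar("b")             # second optimization variable
target = pt.dvector("target")   # extra parameter (theta) of the objective

objective = ((w - target) ** 2).sum() + (b - 1.0) ** 2

# x may now be a list of variables; the solution has one entry per variable.
(w_star, b_star), success = minimize(objective, [w, b])

# Initial values for w and b are supplied when the compiled function is called.
fn = pytensor.function([w, b, target], [w_star, b_star, success])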
@@ -961,39 +1005,30 @@ def L_op(self, inputs, outputs, output_grads):
             return [
                 grad_not_implemented(self, i, inp, msg) for i, inp in enumerate(inputs)
             ]
+
         x_star, _ = outputs
         output_grad, _ = output_grads

         inner_x, *inner_args = self.fgraph.inputs
         inner_fx = self.fgraph.outputs[0]

-        df_dx = (
-            jacobian(inner_fx, inner_x, vectorize=self.use_vectorized_jac)
-            if not self.jac
-            else self.fgraph.outputs[1]
-        )
-        df_dtheta_columns = jacobian(
-            inner_fx,
-            inner_args,
-            disconnected_inputs="ignore",
-            vectorize=self.use_vectorized_jac,
-        )
-
-        grad_wrt_args = implict_optimization_grads(
-            df_dx=df_dx,
-            df_dtheta_columns=df_dtheta_columns,
+        final_grads = implict_optimization_grads(
+            implicit_f=inner_fx,
+            inner_x=inner_x,
+            inner_args=inner_args,
             args=args,
             x_star=x_star,
             output_grad=output_grad,
             fgraph=self.fgraph,
+            use_vectorized_jac=self.use_vectorized_jac,
         )

-        return [zeros_like(x), *grad_wrt_args]
+        return [zeros_like(x), *final_grads]


 def root(
     equations: TensorVariable,
-    variables: TensorVariable,
+    variables: TensorVariable | Sequence[TensorVariable],
     method: str = "hybr",
     jac: bool = True,
     use_vectorized_jac: bool = False,
@@ -1032,11 +1067,13 @@ def root(
     success: TensorVariable
         Boolean indicating whether the root-finding was successful. If True, the solution is a root of the equation
     """
-
-    args = _find_optimization_parameters(equations, variables)
+    packed_variables, packed_shapes, equations = (
+        _maybe_pack_input_variables_and_rewrite_objective(equations, variables)
+    )
+    args = _find_optimization_parameters(equations, packed_variables)

     root_op = RootOp(
-        variables,
+        packed_variables,
         *args,
         equations=equations,
         method=method,
@@ -1045,7 +1082,10 @@
         use_vectorized_jac=use_vectorized_jac,
     )

-    solution, success = root_op(variables, *args)
+    solution, success = root_op(packed_variables, *args)
+    if packed_shapes is not None:
+        solution = unpack(solution, axes=None, packed_shapes=packed_shapes)
+        # rewrite_graph(solution, include=("ShapeOpt", "canonicalize", "stabilize"))

     return solution, success

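root gets the same treatment: several unknowns can be passed as a list, packed into a single vector for scipy.optimize.root, and unpacked in the returned solution. A sketch of the new call pattern (again invented names, untested against this revision):

import pytensor.tensor as pt
from pytensor.tensor.optimize import root

x = pt.dscalar("x")
y = pt.dscalar("y")
a = pt.dscalar("a")              # parameter of the system

# Two equations in two unknowns: x + y - a = 0 and x * y - 1 = 0.
equations = pt.stack([x + y - a, x * y - 1.0])

(x_root, y_root), success = root(equations, [x, y], method="hybr")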