Commit 5859a2e

Add option to vectorize jacobian in minimize/root
1 parent b9a713c commit 5859a2e

File tree: 1 file changed (+33 −12 lines)

pytensor/tensor/optimize.py

Lines changed: 33 additions & 12 deletions
@@ -7,7 +7,7 @@
 
 import pytensor.scalar as ps
 from pytensor.compile.function import function
-from pytensor.gradient import grad, hessian, jacobian
+from pytensor.gradient import grad, jacobian
 from pytensor.graph import Apply, Constant, FunctionGraph
 from pytensor.graph.basic import ancestors, truncated_graph_inputs
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
@@ -483,6 +483,7 @@ def __init__(
         jac: bool = True,
         hess: bool = False,
         hessp: bool = False,
+        use_vectorized_jac: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
         if not cast(TensorVariable, objective).ndim == 0:
@@ -495,6 +496,7 @@ def __init__(
             )
 
         self.fgraph = FunctionGraph([x, *args], [objective])
+        self.use_vectorized_jac = use_vectorized_jac
 
         if jac:
             grad_wrt_x = cast(
@@ -504,7 +506,12 @@ def __init__(
 
         if hess:
             hess_wrt_x = cast(
-                Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+                Variable,
+                jacobian(
+                    self.fgraph.outputs[-1],
+                    self.fgraph.inputs[0],
+                    vectorize=use_vectorized_jac,
+                ),
             )
             self.fgraph.add_output(hess_wrt_x)
 
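For context on the hunk above: when jac=True, the gradient graph is appended as the last fgraph output, so self.fgraph.outputs[-1] is the gradient of the objective, and the Hessian is the Jacobian of that gradient. A minimal sketch of the equivalence, assuming only the public grad/jacobian/hessian helpers (variable names are illustrative):

import pytensor.tensor as pt
from pytensor.gradient import grad, hessian, jacobian

x = pt.vector("x")
objective = (x**2).sum()

g = grad(objective, x)               # the gradient output appended when jac=True
H = jacobian(g, x, vectorize=False)  # Hessian built as the Jacobian of the gradient
# hessian(objective, x) computes the same matrix; routing through jacobian()
# is presumably what lets the vectorize flag pick vmap vs. scan here.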
@@ -563,7 +570,7 @@ def L_op(self, inputs, outputs, output_grads):
             implicit_f,
             [inner_x, *inner_args],
             disconnected_inputs="ignore",
-            vectorize=True,
+            vectorize=self.use_vectorized_jac,
         )
         grad_wrt_args = implict_optimization_grads(
             df_dx=df_dx,
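For context: df_dx feeds the implicit-function-theorem gradient assembled by implict_optimization_grads. Roughly, if x*(theta) satisfies f(x*, theta) = 0 at the solution, then

    dx*/dtheta = -(df/dx)^(-1) @ (df/dtheta)

so the Jacobian df/dx is needed in the backward pass regardless; this hunk only controls whether it is built with vmap or with a scan.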
@@ -583,6 +590,7 @@ def minimize(
     method: str = "BFGS",
     jac: bool = True,
     hess: bool = False,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -592,18 +600,21 @@ def minimize(
     ----------
     objective : TensorVariable
         The objective function to minimize. This should be a pytensor variable representing a scalar value.
-
-    x : TensorVariable
+    x: TensorVariable
         The variable with respect to which the objective function is minimized. It must be an input to the
         computational graph of `objective`.
-
-    method : str, optional
+    method: str, optional
         The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
-
-    jac : bool, optional
-        Whether to compute and use the gradient of teh objective function with respect to x for optimization.
+    jac: bool, optional
+        Whether to compute and use the gradient of the objective function with respect to x for optimization.
         Default is True.
-
+    hess: bool, optional
+        Whether to compute and use the Hessian of the objective function with respect to x for optimization.
+        Default is False. Note that some methods require this, while others do not support it.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the jacobian (and/or hessian) matrix. If False, a
+        scan will be used instead. This comes down to a memory/compute trade-off. Vectorized graphs can be faster,
+        but use more memory. Default is False.
     optimizer_kwargs
         Additional keyword arguments to pass to scipy.optimize.minimize
 
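A hypothetical usage sketch of the new flag on minimize (the x_star/success names are illustrative; the function returns a solution/success pair per the signature above):

import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

x = pt.vector("x")
objective = (x**2).sum()

x_star, success = minimize(
    objective,
    x,
    method="Newton-CG",  # a method that can use the Hessian
    jac=True,
    hess=True,
    use_vectorized_jac=True,  # build the jacobian/hessian graph with vmap, not scan
)

As the docstring notes, this trades memory for speed; the default (False) keeps the scan-based graph.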
@@ -626,6 +637,7 @@ def minimize(
         method=method,
         jac=jac,
         hess=hess,
+        use_vectorized_jac=use_vectorized_jac,
         optimizer_kwargs=optimizer_kwargs,
     )
 
@@ -806,6 +818,7 @@ def __init__(
         method: str = "hybr",
         jac: bool = True,
         optimizer_kwargs: dict | None = None,
+        use_vectorized_jac: bool = False,
     ):
         if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
             raise ValueError(
@@ -820,7 +833,9 @@ def __init__(
 
         if jac:
             jac_wrt_x = jacobian(
-                self.fgraph.outputs[0], self.fgraph.inputs[0], vectorize=True
+                self.fgraph.outputs[0],
+                self.fgraph.inputs[0],
+                vectorize=use_vectorized_jac,
             )
             self.fgraph.add_output(atleast_2d(jac_wrt_x))
 
@@ -927,6 +942,7 @@ def root(
     variables: TensorVariable,
     method: str = "hybr",
     jac: bool = True,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -945,6 +961,10 @@ def root(
     jac : bool, optional
         Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
        Default is True. Most methods require this.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the jacobian matrix. If False, a scan will be used instead.
+        This comes down to a memory/compute trade-off. Vectorized graphs can be faster, but use more memory.
+        Default is False.
     optimizer_kwargs : dict, optional
         Additional keyword arguments to pass to `scipy.optimize.root`.
 
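A hypothetical sketch for root, mirroring the docstring above (the example system and names are illustrative, and it assumes the equations precede the variables in the call, as in the docstring's parameter order):

import pytensor.tensor as pt
from pytensor.tensor.optimize import root

v = pt.vector("v")
equations = pt.stack([v[0] ** 2 + v[1] - 1, v[0] - v[1] ** 2])

solution, success = root(
    equations,
    v,
    method="hybr",
    jac=True,
    use_vectorized_jac=True,  # vmap-based jacobian of the equations
)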
@@ -968,6 +988,7 @@ def root(
         method=method,
         jac=jac,
         optimizer_kwargs=optimizer_kwargs,
+        use_vectorized_jac=use_vectorized_jac,
     )
 
     solution, success = cast(
