
Commit da10fbc

Add option to vectorize jacobian in minimize/root
1 parent 1227c4d commit da10fbc

File tree

1 file changed: +32 -11

pytensor/tensor/optimize.py
@@ -484,6 +484,7 @@ def __init__(
         jac: bool = True,
         hess: bool = False,
         hessp: bool = False,
+        use_vectorized_jac: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
         if not cast(TensorVariable, objective).ndim == 0:
@@ -496,6 +497,7 @@ def __init__(
             )
 
         self.fgraph = FunctionGraph([x, *args], [objective])
+        self.use_vectorized_jac = use_vectorized_jac
 
         if jac:
             grad_wrt_x = cast(
@@ -505,7 +507,12 @@ def __init__(
 
         if hess:
             hess_wrt_x = cast(
-                Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+                Variable,
+                jacobian(
+                    self.fgraph.outputs[-1],
+                    self.fgraph.inputs[0],
+                    vectorize=use_vectorized_jac,
+                ),
             )
             self.fgraph.add_output(hess_wrt_x)
 
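Note on the hess change above: the Hessian is now built as the Jacobian of `self.fgraph.outputs[-1]`, presumably the gradient output appended when `jac=True`, which is what lets the Hessian share the new `vectorize` switch. A minimal sketch of the equivalence, using illustrative names (x, obj) that are not part of this commit:

    import pytensor.tensor as pt
    from pytensor.gradient import grad, hessian, jacobian

    x = pt.dvector("x")
    obj = (x**2).sum()  # scalar objective

    g = grad(obj, x)  # gradient vector of shape (n,)
    H_scan = jacobian(g, x, vectorize=False)  # rows built sequentially with a scan
    H_vmap = jacobian(g, x, vectorize=True)   # one vectorized (vmap) graph
    H_ref = hessian(obj, x)                   # the formulation being replaced

All three build an n-by-n Hessian graph; the vectorized variant trades memory for speed, as the docstrings added below explain.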
@@ -564,7 +571,7 @@ def L_op(self, inputs, outputs, output_grads):
             implicit_f,
             [inner_x, *inner_args],
             disconnected_inputs="ignore",
-            vectorize=True,
+            vectorize=self.use_vectorized_jac,
         )
         grad_wrt_args = implict_optimization_grads(
             df_dx=df_dx,
@@ -584,6 +591,7 @@ def minimize(
     method: str = "BFGS",
     jac: bool = True,
     hess: bool = False,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -593,18 +601,21 @@ def minimize(
     ----------
     objective : TensorVariable
         The objective function to minimize. This should be a pytensor variable representing a scalar value.
-
-    x : TensorVariable
+    x: TensorVariable
         The variable with respect to which the objective function is minimized. It must be an input to the
         computational graph of `objective`.
-
-    method : str, optional
+    method: str, optional
         The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
-
-    jac : bool, optional
-        Whether to compute and use the gradient of teh objective function with respect to x for optimization.
+    jac: bool, optional
+        Whether to compute and use the gradient of the objective function with respect to x for optimization.
         Default is True.
-
+    hess: bool, optional
+        Whether to compute and use the Hessian of the objective function with respect to x for optimization.
+        Default is False. Note that some methods require this, while others do not support it.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the jacobian (and/or hessian) matrix. If False, a
+        scan will be used instead. This comes down to a memory/compute trade-off. Vectorized graphs can be faster,
+        but use more memory. Default is False.
     optimizer_kwargs
         Additional keyword arguments to pass to scipy.optimize.minimize
 
@@ -627,6 +638,7 @@ def minimize(
         method=method,
         jac=jac,
         hess=hess,
+        use_vectorized_jac=use_vectorized_jac,
         optimizer_kwargs=optimizer_kwargs,
     )
 
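The docstring above describes the vmap/scan trade-off behind `use_vectorized_jac`. A usage sketch of the flag as exposed by `minimize`; the objective, variable names, and method choice are illustrative assumptions, not part of the commit:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.optimize import minimize

    x = pt.dvector("x")
    a = pt.dvector("a")
    objective = ((x - a) ** 2).sum()  # scalar objective in x

    # trust-ncg consumes Hessian information, so hess=True is meaningful here;
    # use_vectorized_jac=True builds that Hessian with vmap instead of scan.
    x_star, success = minimize(
        objective,
        x,
        method="trust-ncg",
        jac=True,
        hess=True,
        use_vectorized_jac=True,
    )

    # The value supplied for x when calling the compiled function
    # presumably serves as the initial guess for the optimizer.
    fn = pytensor.function([x, a], [x_star, success])

Whether the vectorized graph pays off depends on problem size: the vmap graph materializes all Jacobian/Hessian rows at once, while the scan version computes them one at a time.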
@@ -807,6 +819,7 @@ def __init__(
         method: str = "hybr",
         jac: bool = True,
         optimizer_kwargs: dict | None = None,
+        use_vectorized_jac: bool = False,
     ):
         if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
             raise ValueError(
@@ -821,7 +834,9 @@ def __init__(
 
         if jac:
             jac_wrt_x = jacobian(
-                self.fgraph.outputs[0], self.fgraph.inputs[0], vectorize=True
+                self.fgraph.outputs[0],
+                self.fgraph.inputs[0],
+                vectorize=use_vectorized_jac,
             )
             self.fgraph.add_output(atleast_2d(jac_wrt_x))
 
@@ -928,6 +943,7 @@ def root(
     variables: TensorVariable,
     method: str = "hybr",
     jac: bool = True,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -946,6 +962,10 @@ def root(
     jac : bool, optional
         Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
         Default is True. Most methods require this.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the jacobian matrix. If False, a scan will be used instead.
+        This comes down to a memory/compute trade-off. Vectorized graphs can be faster, but use more memory.
+        Default is False.
     optimizer_kwargs : dict, optional
         Additional keyword arguments to pass to `scipy.optimize.root`.
 
@@ -969,6 +989,7 @@ def root(
         method=method,
         jac=jac,
         optimizer_kwargs=optimizer_kwargs,
+        use_vectorized_jac=use_vectorized_jac,
     )
 
     solution, success = cast(
solution, success = cast(
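For completeness, the matching sketch for `root`; the example system below is an illustrative assumption:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.optimize import root

    v = pt.dvector("v")
    # Two equations in two unknowns: a line intersected with the unit circle.
    equations = pt.stack([v[0] + 2 * v[1] - 2, v[0] ** 2 + v[1] ** 2 - 1])

    solution, success = root(
        equations,
        v,
        method="hybr",
        jac=True,
        use_vectorized_jac=True,  # Jacobian rows via vmap rather than scan
    )

    fn = pytensor.function([v], [solution, success])

As required by RootOp, `equations` and `v` have the same number of dimensions (both are vectors here).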
