import pytensor.scalar as ps
from pytensor.compile.function import function
-from pytensor.gradient import grad, hessian, jacobian
+from pytensor.gradient import grad, jacobian
from pytensor.graph import Apply, Constant, FunctionGraph
from pytensor.graph.basic import ancestors, truncated_graph_inputs
from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
@@ -483,6 +483,7 @@ def __init__(
        jac: bool = True,
        hess: bool = False,
        hessp: bool = False,
+        use_vectorized_jac: bool = False,
        optimizer_kwargs: dict | None = None,
    ):
        if not cast(TensorVariable, objective).ndim == 0:
@@ -495,6 +496,7 @@ def __init__(
            )

        self.fgraph = FunctionGraph([x, *args], [objective])
+        self.use_vectorized_jac = use_vectorized_jac

        if jac:
            grad_wrt_x = cast(
@@ -504,7 +506,12 @@ def __init__(

        if hess:
            hess_wrt_x = cast(
-                Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+                Variable,
+                jacobian(
+                    self.fgraph.outputs[-1],
+                    self.fgraph.inputs[0],
+                    vectorize=use_vectorized_jac,
+                ),
            )
            self.fgraph.add_output(hess_wrt_x)

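The change above builds the Hessian as the Jacobian of the gradient that was just appended as the graph's last output, instead of calling the standalone hessian helper. A minimal standalone sketch of that equivalence, with illustrative names:

import pytensor.tensor as pt
from pytensor.gradient import grad, jacobian

# Illustrative sketch: for a scalar objective, the Hessian is the Jacobian of
# the gradient, so an existing gradient graph can be reused directly.
x = pt.vector("x")
objective = (x**3).sum() + (x**2).sum()

g = grad(objective, x)                     # gradient, shape (n,)
H_scan = jacobian(g, x, vectorize=False)   # rows built with a scan
H_vmap = jacobian(g, x, vectorize=True)    # one vectorized (vmap-style) graph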
@@ -563,7 +570,7 @@ def L_op(self, inputs, outputs, output_grads):
            implicit_f,
            [inner_x, *inner_args],
            disconnected_inputs="ignore",
-            vectorize=True,
+            vectorize=self.use_vectorized_jac,
        )
        grad_wrt_args = implict_optimization_grads(
            df_dx=df_dx,
@@ -583,6 +590,7 @@ def minimize(
    method: str = "BFGS",
    jac: bool = True,
    hess: bool = False,
+    use_vectorized_jac: bool = False,
    optimizer_kwargs: dict | None = None,
) -> tuple[TensorVariable, TensorVariable]:
    """
@@ -592,18 +600,21 @@ def minimize(
    ----------
    objective : TensorVariable
        The objective function to minimize. This should be a pytensor variable representing a scalar value.
-
-    x : TensorVariable
+    x: TensorVariable
        The variable with respect to which the objective function is minimized. It must be an input to the
        computational graph of `objective`.
-
-    method : str, optional
+    method: str, optional
        The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
-
-    jac : bool, optional
-        Whether to compute and use the gradient of teh objective function with respect to x for optimization.
+    jac: bool, optional
+        Whether to compute and use the gradient of the objective function with respect to x for optimization.
        Default is True.
-
+    hess: bool, optional
+        Whether to compute and use the Hessian of the objective function with respect to x for optimization.
+        Default is False. Note that some methods require this, while others do not support it.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the jacobian (and/or hessian) matrix. If False, a
+        scan will be used instead. This comes down to a memory/compute trade-off. Vectorized graphs can be faster,
+        but use more memory. Default is False.
    optimizer_kwargs
        Additional keyword arguments to pass to scipy.optimize.minimize

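A minimal usage sketch of the new flag, assuming the module shown here is pytensor.tensor.optimize; the objective, method, and variable names below are illustrative, not taken from this commit:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

# Illustrative sketch: a scalar Rosenbrock-style objective minimized with a
# symbolic gradient and Hessian, both built as vectorized graphs rather than scans.
x = pt.vector("x")
a = pt.scalar("a")
obj = (a - x[0]) ** 2 + 100 * (x[1] - x[0] ** 2) ** 2

x_star, success = minimize(
    obj,
    x,
    method="Newton-CG",
    jac=True,
    hess=True,
    use_vectorized_jac=True,  # new flag: vmap-based jacobian/hessian instead of scan
)

fn = pytensor.function([x, a], [x_star, success])
print(fn(np.array([0.0, 0.0]), 1.0))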
@@ -626,6 +637,7 @@ def minimize(
        method=method,
        jac=jac,
        hess=hess,
+        use_vectorized_jac=use_vectorized_jac,
        optimizer_kwargs=optimizer_kwargs,
    )

@@ -806,6 +818,7 @@ def __init__(
        method: str = "hybr",
        jac: bool = True,
        optimizer_kwargs: dict | None = None,
+        use_vectorized_jac: bool = False,
    ):
        if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
            raise ValueError(
@@ -820,7 +833,9 @@ def __init__(

        if jac:
            jac_wrt_x = jacobian(
-                self.fgraph.outputs[0], self.fgraph.inputs[0], vectorize=True
+                self.fgraph.outputs[0],
+                self.fgraph.inputs[0],
+                vectorize=use_vectorized_jac,
            )
            self.fgraph.add_output(atleast_2d(jac_wrt_x))

@@ -927,6 +942,7 @@ def root(
    variables: TensorVariable,
    method: str = "hybr",
    jac: bool = True,
+    use_vectorized_jac: bool = False,
    optimizer_kwargs: dict | None = None,
) -> tuple[TensorVariable, TensorVariable]:
    """
@@ -945,6 +961,10 @@ def root(
    jac : bool, optional
        Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
        Default is True. Most methods require this.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the jacobian matrix. If False, a scan will be used instead.
+        This comes down to a memory/compute trade-off. Vectorized graphs can be faster, but use more memory.
+        Default is False.
    optimizer_kwargs : dict, optional
        Additional keyword arguments to pass to `scipy.optimize.root`.

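A corresponding sketch for root, again assuming the module shown is pytensor.tensor.optimize; the two-equation system below is illustrative:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import root

# Illustrative sketch: solve f(v) = 0 for a small nonlinear system, building
# the Jacobian with a vectorized graph instead of a scan.
v = pt.vector("v")
equations = pt.stack([v[0] + 2 * v[1] - 2, v[0] ** 2 + 4 * v[1] ** 2 - 4])

solution, success = root(
    equations,
    v,
    method="hybr",
    jac=True,
    use_vectorized_jac=True,  # new flag: vmap-based Jacobian instead of scan
)

fn = pytensor.function([v], [solution, success])
print(fn(np.array([1.0, 1.0])))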
@@ -968,6 +988,7 @@ def root(
        method=method,
        jac=jac,
        optimizer_kwargs=optimizer_kwargs,
+        use_vectorized_jac=use_vectorized_jac,
    )

    solution, success = cast(