@@ -484,6 +484,7 @@ def __init__(
         jac: bool = True,
         hess: bool = False,
         hessp: bool = False,
+        use_vectorized_jac: bool = False,
         optimizer_kwargs: dict | None = None,
     ):
         if not cast(TensorVariable, objective).ndim == 0:
@@ -496,6 +497,7 @@ def __init__(
             )
 
         self.fgraph = FunctionGraph([x, *args], [objective])
+        self.use_vectorized_jac = use_vectorized_jac
 
         if jac:
             grad_wrt_x = cast(
@@ -505,7 +507,12 @@ def __init__(
 
         if hess:
             hess_wrt_x = cast(
-                Variable, hessian(self.fgraph.outputs[0], self.fgraph.inputs[0])
+                Variable,
+                jacobian(
+                    self.fgraph.outputs[-1],
+                    self.fgraph.inputs[0],
+                    vectorize=use_vectorized_jac,
+                ),
             )
             self.fgraph.add_output(hess_wrt_x)
 
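The hunk above swaps `hessian(...)` for a `jacobian(...)` of the graph's last output, which at this point is the gradient appended in the `if jac:` branch, so the new `vectorize` flag can be threaded through. A minimal sketch (not part of the diff) of the identity this relies on, assuming a PyTensor build where `jacobian` accepts `vectorize`, as the diff itself does:

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.gradient import grad, hessian, jacobian

x = pt.vector("x")
objective = (x**4).sum()  # toy scalar objective

g = grad(objective, x)                   # gradient, shape (n,)
H_scan = jacobian(g, x)                  # scan-based Jacobian of the gradient
H_vmap = jacobian(g, x, vectorize=True)  # vectorized (vmap-style) alternative
H_ref = hessian(objective, x)            # what the old code computed directly

fn = pytensor.function([x], [H_scan, H_vmap, H_ref])
a, b, c = fn(np.arange(3.0))
assert np.allclose(a, b) and np.allclose(a, c)  # all three agree
```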
@@ -564,7 +571,7 @@ def L_op(self, inputs, outputs, output_grads):
             implicit_f,
             [inner_x, *inner_args],
             disconnected_inputs="ignore",
-            vectorize=True,
+            vectorize=self.use_vectorized_jac,
         )
         grad_wrt_args = implict_optimization_grads(
             df_dx=df_dx,
@@ -584,6 +591,7 @@ def minimize(
     method: str = "BFGS",
     jac: bool = True,
     hess: bool = False,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -593,18 +601,21 @@ def minimize(
     ----------
     objective : TensorVariable
         The objective function to minimize. This should be a pytensor variable representing a scalar value.
-
-    x : TensorVariable
+    x: TensorVariable
         The variable with respect to which the objective function is minimized. It must be an input to the
         computational graph of `objective`.
-
-    method : str, optional
+    method: str, optional
         The optimization method to use. Default is "BFGS". See scipy.optimize.minimize for other options.
-
-    jac : bool, optional
-        Whether to compute and use the gradient of teh objective function with respect to x for optimization.
+    jac: bool, optional
+        Whether to compute and use the gradient of the objective function with respect to x for optimization.
         Default is True.
-
+    hess: bool, optional
+        Whether to compute and use the Hessian of the objective function with respect to x for optimization.
+        Default is False. Note that some methods require this, while others do not support it.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the Jacobian (and/or Hessian) matrix. If False, a
+        scan will be used instead. This comes down to a memory/compute trade-off: vectorized graphs can be
+        faster, but use more memory. Default is False.
     optimizer_kwargs
         Additional keyword arguments to pass to scipy.optimize.minimize
 
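A hypothetical usage sketch of the `minimize` wrapper with the new flag, assuming the signature shown in this diff (the symbolic input `x` doubles as the initial point when the compiled function is called); the specific objective and data here are illustrative only:

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import minimize

x = pt.vector("x")
a = pt.vector("a")
objective = ((x - a) ** 2).sum()  # scalar objective with one extra parameter

x_star, success = minimize(
    objective,
    x,
    method="Newton-CG",       # a scipy method that can use the Hessian
    jac=True,
    hess=True,
    use_vectorized_jac=True,  # build the Hessian graph with vmap instead of scan
)

fn = pytensor.function([x, a], [x_star, success])
x_opt, ok = fn(np.zeros(3), np.array([1.0, 2.0, 3.0]))  # x_opt ≈ [1, 2, 3]
```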
@@ -627,6 +638,7 @@ def minimize(
         method=method,
         jac=jac,
         hess=hess,
+        use_vectorized_jac=use_vectorized_jac,
         optimizer_kwargs=optimizer_kwargs,
     )
 
@@ -807,6 +819,7 @@ def __init__(
         method: str = "hybr",
         jac: bool = True,
         optimizer_kwargs: dict | None = None,
+        use_vectorized_jac: bool = False,
     ):
         if cast(TensorVariable, variables).ndim != cast(TensorVariable, equations).ndim:
             raise ValueError(
@@ -821,7 +834,9 @@ def __init__(
 
         if jac:
             jac_wrt_x = jacobian(
-                self.fgraph.outputs[0], self.fgraph.inputs[0], vectorize=True
+                self.fgraph.outputs[0],
+                self.fgraph.inputs[0],
+                vectorize=use_vectorized_jac,
             )
             self.fgraph.add_output(atleast_2d(jac_wrt_x))
 
@@ -928,6 +943,7 @@ def root(
     variables: TensorVariable,
     method: str = "hybr",
     jac: bool = True,
+    use_vectorized_jac: bool = False,
     optimizer_kwargs: dict | None = None,
 ) -> tuple[TensorVariable, TensorVariable]:
     """
@@ -946,6 +962,10 @@ def root(
     jac : bool, optional
         Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
         Default is True. Most methods require this.
+    use_vectorized_jac: bool, optional
+        Whether to use a vectorized graph (vmap) to compute the Jacobian matrix. If False, a scan will be
+        used instead. This comes down to a memory/compute trade-off: vectorized graphs can be faster, but
+        use more memory. Default is False.
     optimizer_kwargs : dict, optional
         Additional keyword arguments to pass to `scipy.optimize.root`.
 
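A hypothetical usage sketch for the `root` wrapper with the new flag, again assuming the signature shown in this diff; the system solved here (the fixed point of cosine) is illustrative only:

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.optimize import root

v = pt.vector("v")
equations = pt.cos(v) - v  # element-wise system F(v) = 0

solution, success = root(
    equations,
    v,
    method="hybr",
    jac=True,
    use_vectorized_jac=True,  # vmap-based Jacobian instead of scan
)

fn = pytensor.function([v], [solution, success])
sol, ok = fn(np.full(2, 0.5))  # each entry converges to ~0.739085
```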
@@ -969,6 +989,7 @@ def root(
         method=method,
         jac=jac,
         optimizer_kwargs=optimizer_kwargs,
+        use_vectorized_jac=use_vectorized_jac,
     )
 
     solution, success = cast(