@@ -5,6 +5,7 @@
 import pytensor.tensor as pt
 import xarray
 
+from better_optimize import minimize
 from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
 from pymc.backends.arviz import (
@@ -14,7 +15,6 @@
 )
 from pymc.util import RandomSeed, get_default_varnames
 from pytensor.tensor.variable import TensorVariable
-from scipy.optimize import minimize
 
 from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
 from pymc_extras.inference.laplace_approx.scipy_interface import (
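
The import change above swaps `scipy.optimize.minimize` for `better_optimize.minimize` at the call site further down. A minimal sketch of the calling convention this relies on, assuming (as the rest of the diff does) that `better_optimize.minimize` mirrors the `scipy.optimize.minimize` signature, including `jac=True` for objectives that return value and gradient together:

```python
import numpy as np

from better_optimize import minimize  # assumed drop-in for scipy.optimize.minimize


def f_fused(x):
    # Rosenbrock function returning (value, gradient) in one call; jac=True
    # tells the optimizer the gradient is bundled with the objective value.
    value = (1 - x[0]) ** 2 + 100 * (x[1] - x[0] ** 2) ** 2
    grad = np.array(
        [
            -2 * (1 - x[0]) - 400 * x[0] * (x[1] - x[0] ** 2),
            200 * (x[1] - x[0] ** 2),
        ]
    )
    return value, grad


result = minimize(f_fused, x0=np.zeros(2), method="BFGS", jac=True)
```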
@@ -31,6 +31,7 @@ def fit_dadvi(
     optimizer_method: minimize_method = "trust-ncg",
     use_grad: bool = True,
     use_hessp: bool = True,
+    use_hess: bool = False,
     **minimize_kwargs,
 ) -> az.InferenceData:
     """
@@ -82,6 +83,11 @@ def fit_dadvi(
     use_hessp:
         If True, pass the hessian vector product to `scipy.optimize.minimize`.
 
+    use_hess:
+        If True, pass the hessian to `scipy.optimize.minimize`. Note that
+        this is generally not recommended since its computation can be slow
+        and memory-intensive if there are many parameters.
+
     Returns
     -------
     :class:`~arviz.InferenceData`
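
To make the new flag concrete, a hypothetical usage sketch: the import path, the model, and the `model=` keyword are assumptions for illustration; only `optimizer_method`, `use_grad`, `use_hessp`, and `use_hess` come from this diff:

```python
import pymc as pm

from pymc_extras.inference import fit_dadvi  # assumed import path

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu, 1.0, observed=[0.2, -0.4, 0.9])

# Opt into the full Hessian. Per the docstring above, this is usually slower
# and more memory-intensive than the default hessian-vector-product path, so
# it mainly makes sense for small models with Hessian-based methods.
idata = fit_dadvi(
    model=model,  # assumed keyword; not shown in the signature hunk above
    optimizer_method="trust-exact",  # a SciPy method that consumes a full Hessian
    use_grad=True,
    use_hessp=False,
    use_hess=True,
)
```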
@@ -110,9 +116,9 @@ def fit_dadvi(
     f_fused, f_hessp = _compile_functions_for_scipy_optimize(
         objective,
         [var_params],
-        compute_grad=True,
-        compute_hessp=True,
-        compute_hess=False,
+        compute_grad=use_grad,
+        compute_hessp=use_hessp,
+        compute_hess=use_hess,
     )
 
     derivative_kwargs = {}
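
The compile flags now follow the user's choices instead of being hard-coded. A purely hypothetical sketch of what a helper of this shape does (names and body are illustrative, not the actual `_compile_functions_for_scipy_optimize`): it bundles the requested derivatives into one compiled function, plus a standalone Hessian-vector product:

```python
import pytensor
import pytensor.tensor as pt


def compile_fused(loss, x, compute_grad=True, compute_hessp=True, compute_hess=False):
    # Illustrative only: compile `loss` and its requested derivatives with
    # respect to the flat parameter vector `x` into a single function.
    outputs = [loss]
    if compute_grad:
        grad = pt.grad(loss, x)
        outputs.append(grad)
        if compute_hess:
            outputs.append(pytensor.gradient.jacobian(grad, x))
    f_fused = pytensor.function([x], outputs)

    f_hessp = None
    if compute_hessp:
        p = x.type("p")  # a vector of the same type as x
        # Hessian-vector product via double backprop: grad(grad(loss) . p).
        hessp = pt.grad(pt.sum(pt.grad(loss, x) * p), x)
        f_hessp = pytensor.function([x, p], hessp)
    return f_fused, f_hessp
```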
@@ -121,6 +127,8 @@ def fit_dadvi(
         derivative_kwargs["jac"] = True
     if use_hessp:
         derivative_kwargs["hessp"] = f_hessp
+    if use_hess:
+        derivative_kwargs["hess"] = True
 
     result = minimize(
         f_fused,
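
For reference, the assembled keywords follow `scipy.optimize.minimize` conventions: `jac=True` tells the optimizer the objective returns its gradient alongside the value, and `hessp` supplies a Hessian-vector-product callable. Plain SciPy only accepts callables, strings, or update strategies for `hess`, so `hess=True` here relies on `better_optimize` extending the convention to mean the fused function also returns the Hessian, consistent with `compute_hess=use_hess` above. A standalone sketch under that assumption:

```python
import numpy as np

from better_optimize import minimize  # assumed to accept hess=True for fused objectives


def f_fused(x):
    # Simple quadratic returning (value, gradient, Hessian) in one call,
    # the shape of output that compute_grad=True and compute_hess=True
    # would produce in the diff above.
    A = np.array([[3.0, 1.0], [1.0, 2.0]])
    value = 0.5 * x @ A @ x
    grad = A @ x
    return value, grad, A


# jac=True and hess=True signal that both derivatives are bundled with the
# objective, mirroring the derivative_kwargs assembled in this commit.
result = minimize(
    f_fused,
    x0=np.array([1.0, -2.0]),
    method="trust-exact",
    jac=True,
    hess=True,
)
```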