Commit 7cd407e

Author: Martin Ingram
Message: Switch to better_optimize
Parent: 7b84872

1 file changed: +12 −4 lines changed


pymc_extras/inference/dadvi/dadvi.py (12 additions & 4 deletions)
@@ -5,6 +5,7 @@
 import pytensor.tensor as pt
 import xarray
 
+from better_optimize import minimize
 from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
 from pymc.backends.arviz import (
@@ -14,7 +15,6 @@
 )
 from pymc.util import RandomSeed, get_default_varnames
 from pytensor.tensor.variable import TensorVariable
-from scipy.optimize import minimize
 
 from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
 from pymc_extras.inference.laplace_approx.scipy_interface import (
@@ -31,6 +31,7 @@ def fit_dadvi(
     optimizer_method: minimize_method = "trust-ncg",
     use_grad: bool = True,
     use_hessp: bool = True,
+    use_hess: bool = False,
     **minimize_kwargs,
 ) -> az.InferenceData:
     """
@@ -82,6 +83,11 @@ def fit_dadvi(
     use_hessp:
         If True, pass the hessian vector product to `scipy.optimize.minimize`.
 
+    use_hess:
+        If True, pass the hessian to `scipy.optimize.minimize`. Note that
+        this is generally not recommended since its computation can be slow
+        and memory-intensive if there are many parameters.
+
     Returns
     -------
     :class:`~arviz.InferenceData`
@@ -110,9 +116,9 @@ def fit_dadvi(
     f_fused, f_hessp = _compile_functions_for_scipy_optimize(
         objective,
         [var_params],
-        compute_grad=True,
-        compute_hessp=True,
-        compute_hess=False,
+        compute_grad=use_grad,
+        compute_hessp=use_hessp,
+        compute_hess=use_hess,
     )
 
     derivative_kwargs = {}
@@ -121,6 +127,8 @@ def fit_dadvi(
         derivative_kwargs["jac"] = True
     if use_hessp:
         derivative_kwargs["hessp"] = f_hessp
+    if use_hess:
+        derivative_kwargs["hess"] = True
 
     result = minimize(
         f_fused,
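
Aside: the `derivative_kwargs["jac"] = True` wiring in the last hunk follows SciPy's fused-objective convention, and the commit evidently carries the same kwargs over to `better_optimize.minimize` unchanged (only the import is swapped). Under that convention, `jac=True` tells the optimizer that the objective returns a `(value, gradient)` pair from a single call, and `hessp` is a callable computing Hessian-vector products. A minimal, self-contained sketch of the convention with plain `scipy.optimize.minimize`; the quadratic objective here is illustrative, not taken from the commit:

import numpy as np
from scipy.optimize import minimize


def f_fused(x):
    # Fused objective: return (value, gradient) from one call,
    # which is what jac=True tells the optimizer to expect.
    value = 0.5 * x @ x
    grad = x
    return value, grad


def f_hessp(x, p):
    # Hessian-vector product; the Hessian of 0.5 * x @ x is the identity.
    return p


result = minimize(f_fused, np.ones(3), method="trust-ncg", jac=True, hessp=f_hessp)
print(result.x)  # converges to ~[0.0, 0.0, 0.0]

Returning the value and gradient from one fused call avoids evaluating the objective twice per iteration, which is presumably why `_compile_functions_for_scipy_optimize` builds `f_fused` this way.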

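Putting the pieces together, a hypothetical end-to-end call after this commit might look as follows. The toy model and the assumption that `fit_dadvi` takes the PyMC model as its first argument are illustrative; only `optimizer_method`, `use_grad`, `use_hessp`, and `use_hess` are confirmed by the diff:

import pymc as pm

from pymc_extras.inference.dadvi.dadvi import fit_dadvi

# Illustrative toy model (not part of the commit).
with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])

# Assumed call shape: the diff only shows the derivative-related keywords.
idata = fit_dadvi(
    model,
    optimizer_method="trust-ncg",  # default shown in the diff
    use_grad=True,    # compiled gradient, wired up as jac=True
    use_hessp=True,   # Hessian-vector product, wired up as hessp=f_hessp
    use_hess=False,   # full Hessian stays off: per the new docstring it can
                      # be slow and memory-intensive with many parameters
)

Leaving `use_hessp=True` and `use_hess=False` matches the new docstring's advice: methods like `trust-ncg` only need Hessian-vector products, which avoid materializing the full Hessian.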