Skip to content

Commit 3b090ca

Browse files
author
Martin Ingram
committed
Make hessp and jac optional
1 parent 9f86d4f commit 3b090ca

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

pymc_extras/inference/dadvi/dadvi.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ def fit_dadvi(
2929
n_draws: int = 1000,
3030
keep_untransformed: bool = False,
3131
optimizer_method: minimize_method = "trust-ncg",
32+
use_jacobian: bool = True,
33+
use_hessp: bool = True,
3234
**minimize_kwargs,
3335
) -> az.InferenceData:
3436
"""
@@ -73,6 +75,12 @@ def fit_dadvi(
7375
``scipy.optimize.minimize`` function. See the documentation of
7476
that function for details.
7577
78+
use_jacobian:
79+
If True, pass the Jacobian function to `scipy.optimize.minimize`.
80+
81+
use_hessp:
82+
If True, pass the hessian vector product to `scipy.optimize.minimize`.
83+
7684
Returns
7785
-------
7886
:class:`~arviz.InferenceData`
@@ -106,12 +114,18 @@ def fit_dadvi(
106114
compute_hess=False,
107115
)
108116

117+
derivative_kwargs = {}
118+
119+
if use_jacobian:
120+
derivative_kwargs["jac"] = True
121+
if use_hessp:
122+
derivative_kwargs["hessp"] = f_hessp
123+
109124
result = minimize(
110125
f_fused,
111126
np.zeros(2 * n_params),
112-
jac=True,
113127
method=optimizer_method,
114-
hessp=f_hessp,
128+
**derivative_kwargs,
115129
**minimize_kwargs,
116130
)
117131

0 commit comments

Comments
 (0)