5
5
import pytensor .tensor as pt
6
6
import xarray
7
7
8
+ from better_optimize .constants import minimize_method
8
9
from pymc import DictToArrayBijection , Model , join_nonshared_inputs
9
10
from pymc .backends .arviz import (
10
11
PointFunc ,
21
22
)
22
23
23
24
24
- def fit_deterministic_advi (
25
+ def fit_dadvi (
25
26
model : Model | None = None ,
26
27
n_fixed_draws : int = 30 ,
27
28
random_seed : RandomSeed = None ,
28
29
n_draws : int = 1000 ,
29
30
keep_untransformed : bool = False ,
31
+ method : minimize_method = "trust-ncg" ,
32
+ ** minimize_kwargs ,
30
33
) -> az .InferenceData :
31
34
"""
32
35
Does inference using deterministic ADVI (automatic differentiation
33
- variational inference).
36
+ variational inference), DADVI for short.
34
37
35
38
For full details see the paper cited in the references:
36
39
https://www.jmlr.org/papers/v25/23-1015.html
@@ -57,6 +60,19 @@ def fit_deterministic_advi(
57
60
Whether or not to keep the unconstrained variables (such as
58
61
logs of positive-constrained parameters) in the output.
59
62
63
+ method: str
64
+ Which optimization method to use. The function calls
65
+ ``scipy.optimize.minimize``, so any of the methods there can
66
+ be used. The default is ``trust-ncg``, which uses second-order
67
+ information and is generally very reliable. Other methods such
68
+ as L-BFGS-B might be faster but potentially more brittle and
69
+ may not converge exactly to the optimum.
70
+
71
+ minimize_kwargs:
72
+ Additional keyword arguments to pass to the
73
+ ``scipy.optimize.minimize`` function. See the documentation of
74
+ that function for details.
75
+
60
76
Returns
61
77
-------
62
78
:class:`~arviz.InferenceData`
@@ -90,7 +106,14 @@ def fit_deterministic_advi(
90
106
compute_hess = False ,
91
107
)
92
108
93
- result = minimize (f_fused , np .zeros (2 * n_params ), method = "trust-ncg" , jac = True , hessp = f_hessp )
109
+ result = minimize (
110
+ f_fused ,
111
+ np .zeros (2 * n_params ),
112
+ method = method ,
113
+ jac = True ,
114
+ hessp = f_hessp ,
115
+ ** minimize_kwargs ,
116
+ )
94
117
95
118
opt_var_params = result .x
96
119
opt_means , opt_log_sds = np .split (opt_var_params , 2 )
@@ -151,8 +174,7 @@ def create_dadvi_graph(
151
174
)
152
175
153
176
var_params = pt .vector (name = "eta" , shape = (2 * n_params ,))
154
-
155
- means , log_sds = pt .split (var_params , 2 )
177
+ means , log_sds = var_params [:n_params ], var_params [n_params :]
156
178
157
179
draw_matrix = pt .constant (draws )
158
180
samples = means + pt .exp (log_sds ) * draw_matrix
0 commit comments