 import pytensor.tensor as pt
 import xarray

+from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
 from pymc.backends.arviz import (
     PointFunc,
 )


-def fit_deterministic_advi(
+def fit_dadvi(
     model: Model | None = None,
     n_fixed_draws: int = 30,
     random_seed: RandomSeed = None,
     n_draws: int = 1000,
     keep_untransformed: bool = False,
+    method: minimize_method = "trust-ncg",
+    **minimize_kwargs,
 ) -> az.InferenceData:
     """
     Does inference using deterministic ADVI (automatic differentiation
-    variational inference).
+    variational inference), DADVI for short.

     For full details see the paper cited in the references:
     https://www.jmlr.org/papers/v25/23-1015.html
@@ -57,6 +60,19 @@ def fit_deterministic_advi(
         Whether or not to keep the unconstrained variables (such as
         logs of positive-constrained parameters) in the output.

+    method: str
+        Which optimization method to use. The function calls
+        ``scipy.optimize.minimize``, so any of the methods there can
+        be used. The default is trust-ncg, which uses second-order
+        information and is generally very reliable. Other methods such
+        as L-BFGS-B might be faster but potentially more brittle and
+        may not converge exactly to the optimum.
+
+    minimize_kwargs:
+        Additional keyword arguments to pass to the
+        ``scipy.optimize.minimize`` function. See the documentation of
+        that function for details.
+
     Returns
     -------
     :class:`~arviz.InferenceData`
@@ -90,7 +106,14 @@ def fit_deterministic_advi(
         compute_hess=False,
     )

-    result = minimize(f_fused, np.zeros(2 * n_params), method="trust-ncg", jac=True, hessp=f_hessp)
+    result = minimize(
+        f_fused,
+        np.zeros(2 * n_params),
+        method=method,
+        jac=True,
+        hessp=f_hessp,
+        **minimize_kwargs,
+    )

     opt_var_params = result.x
     opt_means, opt_log_sds = np.split(opt_var_params, 2)
@@ -151,8 +174,7 @@ def create_dadvi_graph(
     )

     var_params = pt.vector(name="eta", shape=(2 * n_params,))
-
-    means, log_sds = pt.split(var_params, 2)
+    means, log_sds = var_params[:n_params], var_params[n_params:]

     draw_matrix = pt.constant(draws)
     samples = means + pt.exp(log_sds) * draw_matrix
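
To make the last hunk concrete, here is a minimal standalone sketch of the fixed-draw reparameterization it builds. The shapes and the seed are illustrative assumptions, not values taken from the PR: DADVI samples a matrix of standard-normal noise once up front, then expresses every approximate posterior sample as a deterministic function of the variational means and log standard deviations, which is what makes the objective deterministic and optimizable with standard methods.

```python
# Illustrative sketch; n_params, n_fixed_draws, and the seed are assumptions.
import numpy as np
import pytensor
import pytensor.tensor as pt

n_params, n_fixed_draws = 3, 30
rng = np.random.default_rng(0)

# Variational parameters: means first, then log standard deviations.
var_params = pt.vector(name="eta", shape=(2 * n_params,))
means, log_sds = var_params[:n_params], var_params[n_params:]

# DADVI fixes these standard-normal draws once, instead of resampling
# them at every optimization step as stochastic ADVI does.
draw_matrix = pt.constant(rng.standard_normal((n_fixed_draws, n_params)))

# Each row is one sample from the current Gaussian approximation.
samples = means + pt.exp(log_sds) * draw_matrix

f = pytensor.function([var_params], samples)
print(f(np.zeros(2 * n_params)).shape)  # (30, 3)
```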
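A hypothetical usage sketch of the new keyword arguments follows. The toy model, the import path, and the `options` dict are assumptions for illustration; only `fit_dadvi`, `method`, and `minimize_kwargs` come from this diff, and per the docstring any extra keyword arguments are forwarded to `scipy.optimize.minimize`.

```python
import numpy as np
import pymc as pm

# Import path is an assumption; adjust to wherever fit_dadvi is exposed.
from pymc_extras.inference import fit_dadvi

# A toy model, purely for illustration.
with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal(
        "y", mu, sigma,
        observed=np.random.default_rng(0).normal(1.0, 2.0, size=100),
    )

# Default: trust-ncg, a second-order method using Hessian-vector products.
idata = fit_dadvi(model, random_seed=1)

# Quasi-Newton alternative; the options dict is passed through to
# scipy.optimize.minimize via **minimize_kwargs.
idata_lbfgs = fit_dadvi(
    model,
    method="L-BFGS-B",
    options={"maxiter": 500},
    random_seed=1,
)
```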