@@ -5,7 +5,7 @@
 import pytensor.tensor as pt
 import xarray

-from better_optimize import minimize
+from better_optimize import basinhopping, minimize
 from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
 from pymc.backends.arviz import (
@@ -31,7 +31,6 @@
 def fit_dadvi(
     model: Model | None = None,
     n_fixed_draws: int = 30,
-    random_seed: RandomSeed = None,
     n_draws: int = 1000,
     include_transformed: bool = False,
     optimizer_method: minimize_method = "trust-ncg",
@@ -40,7 +39,9 @@ def fit_dadvi(
     use_hess: bool | None = None,
     gradient_backend: str = "pytensor",
     compile_kwargs: dict | None = None,
-    **minimize_kwargs,
+    random_seed: RandomSeed = None,
+    progressbar: bool = True,
+    **optimizer_kwargs,
 ) -> az.InferenceData:
     """
     Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.
@@ -79,10 +80,6 @@ def fit_dadvi(
     compile_kwargs: dict, optional
         Additional keyword arguments to pass to `pytensor.function`

-    minimize_kwargs:
-        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
-        that function for details.
-
     use_grad: bool, optional
        If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).

@@ -93,6 +90,13 @@ def fit_dadvi(
         If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
         computation can be slow and memory-intensive if there are many parameters.

+    progressbar: bool
+        Whether or not to show a progress bar during optimization. Default is True.
+
+    optimizer_kwargs:
+        Additional keyword arguments to pass to ``scipy.optimize.minimize`` (or to ``scipy.optimize.basinhopping``
+        when ``optimizer_method="basinhopping"``). See the documentation of those functions for details.
+
     Returns
     -------
     :class:`~arviz.InferenceData`
@@ -105,6 +109,16 @@ def fit_dadvi(
     """

     model = pymc.modelcontext(model) if model is None else model
+    do_basinhopping = optimizer_method == "basinhopping"
+    minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+    if do_basinhopping:
+        # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+        # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+        # if one isn't provided.
+
+        optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
+        minimizer_kwargs["method"] = optimizer_method

     initial_point_dict = model.initial_point()
     initial_point = DictToArrayBijection.map(initial_point_dict)
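Read as a user-facing change, the hunk above means basin hopping is selected through the existing `optimizer_method` argument, while settings for the inner optimizer travel in a nested `minimizer_kwargs` dict that falls back to L-BFGS-B when no inner method is given. A minimal sketch of the intended call, assuming `fit_dadvi` is importable from pymc-extras (the import path and toy model are assumptions, not part of the diff):

```python
import pymc as pm

from pymc_extras.inference import fit_dadvi  # import path is an assumption

with pm.Model() as model:
    x = pm.Normal("x", mu=0.0, sigma=1.0)

# Outer optimizer: scipy.optimize.basinhopping.
# Inner optimizer: trust-ncg, passed through minimizer_kwargs;
# omitting "method" here would fall back to L-BFGS-B.
idata = fit_dadvi(
    model=model,
    optimizer_method="basinhopping",
    minimizer_kwargs={"method": "trust-ncg"},
)
```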
@@ -145,14 +159,34 @@ def fit_dadvi(
     )

     dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
-
-    result = minimize(
-        f=f_fused,
-        x0=dadvi_initial_point.data,
-        method=optimizer_method,
-        hessp=f_hessp,
-        **minimize_kwargs,
-    )
+    args = optimizer_kwargs.pop("args", ())
+
+    if do_basinhopping:
+        if "args" not in minimizer_kwargs:
+            minimizer_kwargs["args"] = args
+        if "hessp" not in minimizer_kwargs:
+            minimizer_kwargs["hessp"] = f_hessp
+        if "method" not in minimizer_kwargs:
+            minimizer_kwargs["method"] = optimizer_method
+
+        result = basinhopping(
+            func=f_fused,
+            x0=dadvi_initial_point.data,
+            progressbar=progressbar,
+            minimizer_kwargs=minimizer_kwargs,
+            **optimizer_kwargs,
+        )
+
+    else:
+        result = minimize(
+            f=f_fused,
+            x0=dadvi_initial_point.data,
+            args=args,
+            method=optimizer_method,
+            hessp=f_hessp,
+            progressbar=progressbar,
+            **optimizer_kwargs,
+        )

     raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)

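The branch structure mirrors how SciPy itself splits these interfaces: `scipy.optimize.minimize` accepts `method` and `hessp` directly, whereas `scipy.optimize.basinhopping` has no such arguments of its own and forwards them to its inner minimizer via `minimizer_kwargs`. A plain-SciPy sketch of the same dispatch (better_optimize layers a progress bar on top of these signatures; the quadratic objective is a stand-in for the compiled DADVI loss):

```python
import numpy as np
from scipy.optimize import basinhopping, minimize


def objective(z):
    # Stand-in for f_fused, the compiled DADVI objective.
    return float(np.sum(z**2))


x0 = np.zeros(3)
use_basinhopping = True

if use_basinhopping:
    # method (and hessp, for methods that use it) must ride along inside
    # minimizer_kwargs; basinhopping only drives the outer restart loop.
    result = basinhopping(objective, x0, minimizer_kwargs={"method": "L-BFGS-B"})
else:
    result = minimize(objective, x0, method="L-BFGS-B")

print(result.x)
```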
@@ -166,7 +200,9 @@ def fit_dadvi(
     draws = opt_means + draws_raw * np.exp(opt_log_sds)
     draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)

-    idata = dadvi_result_to_idata(draws_arviz, model, include_transformed=include_transformed)
+    idata = dadvi_result_to_idata(
+        draws_arviz, model, include_transformed=include_transformed, progressbar=progressbar
+    )

     var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
     var_name_to_model_var.update(
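The first line of this hunk is the actual DADVI posterior sampling: each draw is a location-scale transform of standard-normal noise, `mu + eps * exp(log_sd)`, using the optimized variational means and log standard deviations. A self-contained sketch with hypothetical shapes and values:

```python
import numpy as np

rng = np.random.default_rng(0)
n_draws, n_params = 1_000, 4

# Optimized variational parameters (hypothetical values; in the real code
# they are unpacked from result.x).
opt_means = np.zeros(n_params)         # mu
opt_log_sds = np.full(n_params, -1.0)  # log sigma

# Location-scale transform: N(0, 1) noise -> N(mu, sigma) draws.
draws_raw = rng.standard_normal((n_draws, n_params))
draws = opt_means + draws_raw * np.exp(opt_log_sds)
```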
@@ -253,6 +289,7 @@ def dadvi_result_to_idata(
     unstacked_draws: xarray.Dataset,
     model: Model,
     include_transformed: bool = False,
+    progressbar: bool = True,
 ):
     """
     Transforms the unconstrained draws back into the constrained space.
@@ -271,6 +308,9 @@ def dadvi_result_to_idata(
     include_transformed: bool
        Whether or not to keep the unconstrained variables in the output.

+    progressbar: bool
+        Whether or not to show a progress bar during the transformation. Default is True.
+
     Returns
     -------
     :class:`~arviz.InferenceData`
@@ -292,6 +332,7 @@ def dadvi_result_to_idata(
         output_var_names=[x.name for x in vars_to_sample],
         coords=coords,
         dims=dims,
+        progressbar=progressbar,
     )

     constrained_names = [
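Putting the pieces together, an end-to-end call after this change might look like the following sketch (module path, model, and the forwarded `tol` value are assumptions, not part of the diff):

```python
import numpy as np
import pymc as pm

from pymc_extras.inference import fit_dadvi  # import path is an assumption

rng = np.random.default_rng(42)
y_obs = rng.normal(loc=1.0, scale=2.0, size=100)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("y", mu, sigma, observed=y_obs)

idata = fit_dadvi(
    model=model,
    n_draws=500,
    random_seed=42,     # keyword now sits at the end of the signature
    progressbar=False,  # new flag: silences optimization and transform bars
    tol=1e-8,           # forwarded to scipy.optimize.minimize via **optimizer_kwargs
)
```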