     apply_function_over_dataset,
     coords_and_dims_for_inferencedata,
 )
+from pymc.blocking import RaveledVars
 from pymc.util import RandomSeed, get_default_varnames
 from pytensor.tensor.variable import TensorVariable
 
+from pymc_extras.inference.laplace_approx.idata import (
+    add_data_to_inference_data,
+    add_optimizer_result_to_inference_data,
+)
 from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
 from pymc_extras.inference.laplace_approx.scipy_interface import (
-    _compile_functions_for_scipy_optimize,
+    scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
 )
 
 
@@ -29,64 +35,63 @@ def fit_dadvi(
     n_draws: int = 1000,
     keep_untransformed: bool = False,
     optimizer_method: minimize_method = "trust-ncg",
-    use_grad: bool = True,
-    use_hessp: bool = True,
-    use_hess: bool = False,
+    use_grad: bool | None = None,
+    use_hessp: bool | None = None,
+    use_hess: bool | None = None,
+    gradient_backend: str = "pytensor",
+    compile_kwargs: dict | None = None,
     **minimize_kwargs,
 ) -> az.InferenceData:
     """
-    Does inference using deterministic ADVI (automatic differentiation
-    variational inference), DADVI for short.
+    Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.
 
-    For full details see the paper cited in the references:
-    https://www.jmlr.org/papers/v25/23-1015.html
+    For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html
 
     Parameters
     ----------
     model : pm.Model
         The PyMC model to be fit. If None, the current model context is used.
 
     n_fixed_draws : int
-        The number of fixed draws to use for the optimisation. More
-        draws will result in more accurate estimates, but also
-        increase inference time. Usually, the default of 30 is a good
-        tradeoff.between speed and accuracy.
+        The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
+        also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.
 
     random_seed: int
-        The random seed to use for the fixed draws. Running the optimisation
-        twice with the same seed should arrive at the same result.
+        The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
+        the same result.
 
     n_draws: int
         The number of draws to return from the variational approximation.
 
     keep_untransformed: bool
-        Whether or not to keep the unconstrained variables (such as
-        logs of positive-constrained parameters) in the output.
+        Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
+        output.
 
     optimizer_method: str
-        Which optimization method to use. The function calls
-        ``scipy.optimize.minimize``, so any of the methods there can
-        be used. The default is trust-ncg, which uses second-order
-        information and is generally very reliable. Other methods such
-        as L-BFGS-B might be faster but potentially more brittle and
-        may not converge exactly to the optimum.
+        Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
+        can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
+        Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
+        the optimum.
+
+    gradient_backend: str
+        Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".
+
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to `pytensor.function`
 
     minimize_kwargs:
-        Additional keyword arguments to pass to the
-        ``scipy.optimize.minimize`` function. See the documentation of
+        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
         that function for details.
 
-    use_grad:
-        If True, pass the gradient function to
-        `scipy.optimize.minimize` (where it is referred to as `jac`).
+    use_grad: bool, optional
+        If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).
 
-    use_hessp:
+    use_hessp: bool, optional
         If True, pass the hessian vector product to `scipy.optimize.minimize`.
 
-    use_hess:
-        If True, pass the hessian to `scipy.optimize.minimize`. Note that
-        this is generally not recommended since its computation can be slow
-        and memory-intensive if there are many parameters.
+    use_hess: bool, optional
+        If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
+        computation can be slow and memory-intensive if there are many parameters.
 
     Returns
     -------
@@ -95,16 +100,15 @@ def fit_dadvi(
 
     References
    ----------
-    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
-    Variational Inference with a Deterministic Objective: Faster, More
-    Accurate, and Even More Black Box. Journal of Machine Learning
-    Research, 25(18), 1–39.
+    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
+    Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
     """
 
     model = pymc.modelcontext(model) if model is None else model
 
     initial_point_dict = model.initial_point()
-    n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]
+    initial_point = DictToArrayBijection.map(initial_point_dict)
+    n_params = initial_point.data.shape[0]
 
     var_params, objective = create_dadvi_graph(
         model,
@@ -113,31 +117,45 @@ def fit_dadvi(
         n_params=n_params,
     )
 
-    f_fused, f_hessp = _compile_functions_for_scipy_optimize(
-        objective,
-        [var_params],
-        compute_grad=use_grad,
-        compute_hessp=use_hessp,
-        compute_hess=use_hess,
+    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+        optimizer_method, use_grad, use_hess, use_hessp
+    )
+
+    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
+        loss=objective,
+        inputs=[var_params],
+        initial_point_dict=None,
+        use_grad=use_grad,
+        use_hessp=use_hessp,
+        use_hess=use_hess,
+        gradient_backend=gradient_backend,
+        compile_kwargs=compile_kwargs,
+        inputs_are_flat=True,
     )
 
-    derivative_kwargs = {}
+    dadvi_initial_point = {
+        f"{var_name}_mu": np.zeros_like(value).ravel()
+        for var_name, value in initial_point_dict.items()
+    }
+    dadvi_initial_point.update(
+        {
+            f"{var_name}_sigma__log": np.zeros_like(value).ravel()
+            for var_name, value in initial_point_dict.items()
+        }
+    )
 
-    if use_grad:
-        derivative_kwargs["jac"] = True
-    if use_hessp:
-        derivative_kwargs["hessp"] = f_hessp
-    if use_hess:
-        derivative_kwargs["hess"] = True
+    dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
 
     result = minimize(
-        f_fused,
-        np.zeros(2 * n_params),
+        f=f_fused,
+        x0=dadvi_initial_point.data,
         method=optimizer_method,
-        **derivative_kwargs,
+        hessp=f_hessp,
         **minimize_kwargs,
     )
 
+    raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)
+
     opt_var_params = result.x
     opt_means, opt_log_sds = np.split(opt_var_params, 2)
 
@@ -148,9 +166,29 @@ def fit_dadvi(
     draws = opt_means + draws_raw * np.exp(opt_log_sds)
     draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)
 
-    transformed_draws = transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)
+    idata = az.InferenceData(
+        posterior=transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)
+    )
+
+    var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
+    var_name_to_model_var.update(
+        {f"{var_name}_sigma__log": var_name for var_name in initial_point_dict.keys()}
+    )
+
+    idata = add_optimizer_result_to_inference_data(
+        idata=idata,
+        result=result,
+        method=optimizer_method,
+        mu=raveled_optimized,
+        model=model,
+        var_name_to_model_var=var_name_to_model_var,
+    )
+
+    idata = add_data_to_inference_data(
+        idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
+    )
 
-    return transformed_draws
+    return idata
 
 
 def create_dadvi_graph(
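
For orientation, a minimal usage sketch of the updated signature follows. The import path, the toy model, and the argument values are illustrative assumptions rather than part of this commit; with the new defaults, use_grad, use_hessp, and use_hess are left as None so that suitable settings are chosen for the selected optimizer method.

# Illustrative sketch only: the import path and the toy model are assumptions,
# not part of this diff.
import numpy as np
import pymc as pm

from pymc_extras.inference.dadvi import fit_dadvi  # assumed import path

rng = np.random.default_rng(0)
y_obs = rng.normal(loc=1.0, scale=2.0, size=100)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)

# use_grad / use_hessp / use_hess default to None and are resolved per
# optimizer method; gradient_backend and compile_kwargs are the new knobs.
idata = fit_dadvi(
    model=model,
    n_draws=1000,
    optimizer_method="trust-ncg",
    gradient_backend="pytensor",
    compile_kwargs={"mode": "FAST_RUN"},
)
print(idata.posterior)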
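The sampling step kept as context in the last hunk (opt_means + draws_raw * np.exp(opt_log_sds)) relies on the mean-field parameterization, in which the flat optimized vector stacks the variational means followed by the log standard deviations. A small self-contained sketch of that step, with made-up numbers:

import numpy as np

n_params, n_draws = 3, 5
# Optimized variational parameters: 3 means followed by 3 log standard deviations.
opt_var_params = np.array([0.1, -0.2, 0.3, np.log(0.5), np.log(1.0), np.log(2.0)])

opt_means, opt_log_sds = np.split(opt_var_params, 2)
rng = np.random.default_rng(42)
draws_raw = rng.standard_normal((n_draws, n_params))

# Reparameterization: z = mu + eps * sigma, broadcast over the draws axis.
draws = opt_means + draws_raw * np.exp(opt_log_sds)
print(draws.shape)  # (5, 3)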