
Commit fbf4763

Add more options to DADVI minimization
1 parent 07c6ab4 commit fbf4763

7 files changed: +301, -96 lines


pymc_extras/inference/dadvi/dadvi.py

Lines changed: 93 additions & 55 deletions
@@ -13,12 +13,18 @@
     apply_function_over_dataset,
     coords_and_dims_for_inferencedata,
 )
+from pymc.blocking import RaveledVars
 from pymc.util import RandomSeed, get_default_varnames
 from pytensor.tensor.variable import TensorVariable
 
+from pymc_extras.inference.laplace_approx.idata import (
+    add_data_to_inference_data,
+    add_optimizer_result_to_inference_data,
+)
 from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
 from pymc_extras.inference.laplace_approx.scipy_interface import (
-    _compile_functions_for_scipy_optimize,
+    scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
 )
 
 
@@ -29,64 +35,63 @@ def fit_dadvi(
     n_draws: int = 1000,
     keep_untransformed: bool = False,
     optimizer_method: minimize_method = "trust-ncg",
-    use_grad: bool = True,
-    use_hessp: bool = True,
-    use_hess: bool = False,
+    use_grad: bool | None = None,
+    use_hessp: bool | None = None,
+    use_hess: bool | None = None,
+    gradient_backend: str = "pytensor",
+    compile_kwargs: dict | None = None,
     **minimize_kwargs,
 ) -> az.InferenceData:
     """
-    Does inference using deterministic ADVI (automatic differentiation
-    variational inference), DADVI for short.
+    Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.
 
-    For full details see the paper cited in the references:
-    https://www.jmlr.org/papers/v25/23-1015.html
+    For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html
 
     Parameters
     ----------
     model : pm.Model
         The PyMC model to be fit. If None, the current model context is used.
 
     n_fixed_draws : int
-        The number of fixed draws to use for the optimisation. More
-        draws will result in more accurate estimates, but also
-        increase inference time. Usually, the default of 30 is a good
-        tradeoff.between speed and accuracy.
+        The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
+        also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.
 
     random_seed: int
-        The random seed to use for the fixed draws. Running the optimisation
-        twice with the same seed should arrive at the same result.
+        The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
+        the same result.
 
     n_draws: int
         The number of draws to return from the variational approximation.
 
     keep_untransformed: bool
-        Whether or not to keep the unconstrained variables (such as
-        logs of positive-constrained parameters) in the output.
+        Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
+        output.
 
     optimizer_method: str
-        Which optimization method to use. The function calls
-        ``scipy.optimize.minimize``, so any of the methods there can
-        be used. The default is trust-ncg, which uses second-order
-        information and is generally very reliable. Other methods such
-        as L-BFGS-B might be faster but potentially more brittle and
-        may not converge exactly to the optimum.
+        Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
+        can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
+        Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
+        the optimum.
+
+    gradient_backend: str
+        Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".
+
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to `pytensor.function`
 
     minimize_kwargs:
-        Additional keyword arguments to pass to the
-        ``scipy.optimize.minimize`` function. See the documentation of
+        Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
         that function for details.
 
-    use_grad:
-        If True, pass the gradient function to
-        `scipy.optimize.minimize` (where it is referred to as `jac`).
+    use_grad: bool, optional
+        If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).
 
-    use_hessp:
+    use_hessp: bool, optional
         If True, pass the hessian vector product to `scipy.optimize.minimize`.
 
-    use_hess:
-        If True, pass the hessian to `scipy.optimize.minimize`. Note that
-        this is generally not recommended since its computation can be slow
-        and memory-intensive if there are many parameters.
+    use_hess: bool, optional
+        If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
+        computation can be slow and memory-intensive if there are many parameters.
 
     Returns
     -------
@@ -95,16 +100,15 @@ def fit_dadvi(
 
     References
     ----------
-    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
-    Variational Inference with a Deterministic Objective: Faster, More
-    Accurate, and Even More Black Box. Journal of Machine Learning
-    Research, 25(18), 1–39.
+    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
+    Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
     """
 
     model = pymc.modelcontext(model) if model is None else model
 
     initial_point_dict = model.initial_point()
-    n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]
+    initial_point = DictToArrayBijection.map(initial_point_dict)
+    n_params = initial_point.data.shape[0]
 
     var_params, objective = create_dadvi_graph(
         model,
@@ -113,31 +117,45 @@ def fit_dadvi(
         n_params=n_params,
     )
 
-    f_fused, f_hessp = _compile_functions_for_scipy_optimize(
-        objective,
-        [var_params],
-        compute_grad=use_grad,
-        compute_hessp=use_hessp,
-        compute_hess=use_hess,
+    use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+        optimizer_method, use_grad, use_hess, use_hessp
+    )
+
+    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
+        loss=objective,
+        inputs=[var_params],
+        initial_point_dict=None,
+        use_grad=use_grad,
+        use_hessp=use_hessp,
+        use_hess=use_hess,
+        gradient_backend=gradient_backend,
+        compile_kwargs=compile_kwargs,
+        inputs_are_flat=True,
     )
 
-    derivative_kwargs = {}
+    dadvi_initial_point = {
+        f"{var_name}_mu": np.zeros_like(value).ravel()
+        for var_name, value in initial_point_dict.items()
+    }
+    dadvi_initial_point.update(
+        {
+            f"{var_name}_sigma__log": np.zeros_like(value).ravel()
+            for var_name, value in initial_point_dict.items()
+        }
+    )
 
-    if use_grad:
-        derivative_kwargs["jac"] = True
-    if use_hessp:
-        derivative_kwargs["hessp"] = f_hessp
-    if use_hess:
-        derivative_kwargs["hess"] = True
+    dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
 
     result = minimize(
-        f_fused,
-        np.zeros(2 * n_params),
+        f=f_fused,
+        x0=dadvi_initial_point.data,
         method=optimizer_method,
-        **derivative_kwargs,
+        hessp=f_hessp,
         **minimize_kwargs,
     )
 
+    raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)
+
     opt_var_params = result.x
     opt_means, opt_log_sds = np.split(opt_var_params, 2)
 
@@ -148,9 +166,29 @@ def fit_dadvi(
     draws = opt_means + draws_raw * np.exp(opt_log_sds)
     draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)
 
-    transformed_draws = transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)
+    idata = az.InferenceData(
+        posterior=transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)
+    )
+
+    var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
+    var_name_to_model_var.update(
+        {f"{var_name}_sigma__log": var_name for var_name in initial_point_dict.keys()}
+    )
+
+    idata = add_optimizer_result_to_inference_data(
+        idata=idata,
+        result=result,
+        method=optimizer_method,
+        mu=raveled_optimized,
+        model=model,
+        var_name_to_model_var=var_name_to_model_var,
+    )
+
+    idata = add_data_to_inference_data(
+        idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
+    )
 
-    return transformed_draws
+    return idata
 
 
 def create_dadvi_graph(
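
To make the new signature concrete, here is a minimal usage sketch (not part of the commit). It assumes `fit_dadvi` is importable from the module path shown in the diff header and that the keyword names follow the updated signature above; the model and all settings are illustrative only.

import numpy as np
import pymc as pm

from pymc_extras.inference.dadvi.dadvi import fit_dadvi

rng = np.random.default_rng(0)
observed = rng.normal(loc=1.0, scale=2.0, size=100)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

    idata = fit_dadvi(
        model=model,
        n_draws=1000,
        optimizer_method="trust-ncg",          # second-order method; hessp is compiled and passed to minimize
        gradient_backend="pytensor",           # new option: "pytensor" or "jax"
        compile_kwargs={"mode": "FAST_RUN"},   # new option: forwarded to pytensor.function (assumed kwargs)
        # use_grad / use_hessp / use_hess now default to None and are resolved
        # per optimizer method by set_optimizer_function_defaults.
    )

Per the diff, the returned InferenceData should now also carry an "optimizer_result" group (added via add_optimizer_result_to_inference_data) alongside the posterior draws.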

pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 2 additions & 30 deletions
@@ -7,7 +7,7 @@
 import pymc as pm
 
 from better_optimize import basinhopping, minimize
-from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method
+from better_optimize.constants import minimize_method
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.initial_point import make_initial_point_fn
 from pymc.model.transform.optimization import freeze_dims_and_data
@@ -24,40 +24,12 @@
 from pymc_extras.inference.laplace_approx.scipy_interface import (
     GradientBackend,
     scipy_optimize_funcs_from_loss,
+    set_optimizer_function_defaults,
 )
 
 _log = logging.getLogger(__name__)
 
 
-def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
-    method_info = MINIMIZE_MODE_KWARGS[method].copy()
-
-    if use_hess and use_hessp:
-        _log.warning(
-            'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
-            'same time. When possible "use_hessp" is preferred because its is computationally more efficient. '
-            'Setting "use_hess" to False.'
-        )
-        use_hess = False
-
-    use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-
-    if use_hessp is not None and use_hess is None:
-        use_hess = not use_hessp
-
-    elif use_hess is not None and use_hessp is None:
-        use_hessp = not use_hess
-
-    elif use_hessp is None and use_hess is None:
-        use_hessp = method_info["uses_hessp"]
-        use_hess = method_info["uses_hess"]
-        if use_hessp and use_hess:
-            # If a method could use either hess or hessp, we default to using hessp
-            use_hess = False
-
-    return use_grad, use_hess, use_hessp
-
-
 def get_nearest_psd(A: np.ndarray) -> np.ndarray:
     """
     Compute the nearest positive semi-definite matrix to a given matrix.
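
Note that `set_optimizer_function_defaults` is moved into `scipy_interface`, not deleted; the removed lines above show its resolution logic. Below is a hedged sketch of the expected behaviour after the move, assuming the new import path and assuming `better_optimize` reports trust-ncg as supporting gradients, `hess`, and `hessp`.

from pymc_extras.inference.laplace_approx.scipy_interface import (
    set_optimizer_function_defaults,
)

# No explicit choices: fall back to what the method supports. When a method
# could use either hess or hessp, hessp is preferred.
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
    "trust-ncg", use_grad=None, use_hess=None, use_hessp=None
)
print(use_grad, use_hess, use_hessp)  # expected: True False True

# Asking for both hess and hessp logs a warning and drops hess.
use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
    "trust-ncg", use_grad=True, use_hess=True, use_hessp=True
)
print(use_grad, use_hess, use_hessp)  # expected: True False True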

pymc_extras/inference/laplace_approx/idata.py

Lines changed: 21 additions & 3 deletions
@@ -22,10 +22,15 @@ def make_default_labels(name: str, shape: tuple[int, ...]) -> list:
     return [list(range(dim)) for dim in shape]
 
 
-def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str]:
+def make_unpacked_variable_names(
+    names: list[str], model: pm.Model, var_name_to_model_var: dict[str, str] | None = None
+) -> list[str]:
     coords = model.coords
     initial_point = model.initial_point()
 
+    if var_name_to_model_var is None:
+        var_name_to_model_var = {}
+
     value_to_dim = {
         value.name: model.named_vars_to_dims.get(model.values_to_rvs[value].name, None)
         for value in model.value_vars
@@ -37,6 +42,7 @@ def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str]
 
     unpacked_variable_names = []
     for name in names:
+        name = var_name_to_model_var.get(name, name)
         shape = initial_point[name].shape
         if shape:
             dims = dims_dict.get(name)
@@ -258,6 +264,7 @@ def optimizer_result_to_dataset(
     method: minimize_method | Literal["basinhopping"],
     mu: RaveledVars | None = None,
     model: pm.Model | None = None,
+    var_name_to_model_var: dict[str, str] | None = None,
 ) -> xr.Dataset:
     """
     Convert an OptimizeResult object to an xarray Dataset object.
@@ -268,6 +275,9 @@ def optimizer_result_to_dataset(
         The result of the optimization process.
     method: minimize_method or "basinhopping"
         The optimization method used.
+    var_name_to_model_var: dict, optional
+        Mapping between variables in the optimization result and the model variable names. Used when auxiliary
+        variables were introduced, e.g. in DADVI.
 
     Returns
     -------
@@ -279,7 +289,9 @@ def optimizer_result_to_dataset(
 
     model = pm.modelcontext(model) if model is None else model
     variable_names, *_ = zip(*mu.point_map_info)
-    unpacked_variable_names = make_unpacked_variable_names(variable_names, model)
+    unpacked_variable_names = make_unpacked_variable_names(
+        variable_names, model, var_name_to_model_var
+    )
 
     data_vars = {}
 
@@ -368,6 +380,7 @@ def add_optimizer_result_to_inference_data(
     method: minimize_method | Literal["basinhopping"],
     mu: RaveledVars | None = None,
     model: pm.Model | None = None,
+    var_name_to_model_var: dict[str, str] | None = None,
 ) -> az.InferenceData:
     """
     Add the optimization result to an InferenceData object.
@@ -384,13 +397,18 @@ def add_optimizer_result_to_inference_data(
         The MAP estimate of the model parameters.
     model: Model, optional
         A PyMC model. If None, the model is taken from the current model context.
+    var_name_to_model_var: dict, optional
+        Mapping between variables in the optimization result and the model variable names. Used when auxiliary
+        variables were introduced, e.g. in DADVI.
 
     Returns
     -------
     idata: az.InferenceData
         The provided InferenceData, with the optimization results added to the "optimizer" group.
     """
-    dataset = optimizer_result_to_dataset(result, method=method, mu=mu, model=model)
+    dataset = optimizer_result_to_dataset(
+        result, method=method, mu=mu, model=model, var_name_to_model_var=var_name_to_model_var
+    )
     idata.add_groups({"optimizer_result": dataset})
 
     return idata