
Commit 24e18e8

Authored by Martin Ingram (martiningram), with ricardoV94 and jessegrabowski
Add deterministic advi (#564)
* Add first version of deterministic ADVI
* Update API
* Add a notebook example
* Add to API and add a docstring
* Change import in notebook
* Add jax to dependencies
* Add pytensor version
* Fix handling of pymc model
* Add (probably suboptimal) handling of the two backends
* Add transformation
* Follow Ricardo's advice to simplify the transformation step
* Fix naming bug
* Document and clean up
* Fix example
* Update pymc_extras/inference/deterministic_advi/dadvi.py (Co-authored-by: Ricardo Vieira <[email protected]>)
* Respond to comments
* Fix with pre commit checks
* Update pymc_extras/inference/deterministic_advi/dadvi.py (Co-authored-by: Jesse Grabowski <[email protected]>)
* Implement suggestions
* Rename parameter because it's duplicated otherwise
* Rename to be consistent in use of dadvi
* Rename to `optimizer_method` and drop jac=True
* Add jac=True back in since trust-ncg complained
* Make hessp and jac optional
* Harmonize naming with existing code
* Fix example
* Switch to `better_optimize`
* Replace with pt.split

---------

Co-authored-by: Martin Ingram <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Jesse Grabowski <[email protected]>
1 parent d5f8f76 commit 24e18e8

File tree

5 files changed (+1249, -1 lines)


notebooks/deterministic_advi_example.ipynb

Lines changed: 975 additions & 0 deletions
Large diffs are not rendered by default.

pymc_extras/inference/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -12,9 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from pymc_extras.inference.dadvi.dadvi import fit_dadvi
 from pymc_extras.inference.fit import fit
 from pymc_extras.inference.laplace_approx.find_map import find_MAP
 from pymc_extras.inference.laplace_approx.laplace import fit_laplace
 from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder

-__all__ = ["find_MAP", "fit", "fit_laplace", "fit_pathfinder"]
+__all__ = [
+    "find_MAP",
+    "fit",
+    "fit_laplace",
+    "fit_pathfinder",
+    "fit_dadvi",
+]
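
With this export in place, the new fitter is importable directly from the inference namespace; a minimal sketch:

from pymc_extras.inference import fit_dadvi

# `model` defaults to None, so calling fit_dadvi() inside a PyMC model
# context uses the current model.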

pymc_extras/inference/dadvi/__init__.py

Whitespace-only changes.
pymc_extras/inference/dadvi/dadvi.py

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
import arviz as az
import numpy as np
import pymc
import pytensor
import pytensor.tensor as pt
import xarray

from better_optimize import minimize
from better_optimize.constants import minimize_method
from pymc import DictToArrayBijection, Model, join_nonshared_inputs
from pymc.backends.arviz import (
    PointFunc,
    apply_function_over_dataset,
    coords_and_dims_for_inferencedata,
)
from pymc.util import RandomSeed, get_default_varnames
from pytensor.tensor.variable import TensorVariable

from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
from pymc_extras.inference.laplace_approx.scipy_interface import (
    _compile_functions_for_scipy_optimize,
)


def fit_dadvi(
    model: Model | None = None,
    n_fixed_draws: int = 30,
    random_seed: RandomSeed = None,
    n_draws: int = 1000,
    keep_untransformed: bool = False,
    optimizer_method: minimize_method = "trust-ncg",
    use_grad: bool = True,
    use_hessp: bool = True,
    use_hess: bool = False,
    **minimize_kwargs,
) -> az.InferenceData:
    """
    Performs inference using deterministic ADVI (automatic differentiation
    variational inference), DADVI for short.

    For full details see the paper cited in the references:
    https://www.jmlr.org/papers/v25/23-1015.html

    Parameters
    ----------
    model : pm.Model
        The PyMC model to be fit. If None, the current model context is used.

    n_fixed_draws : int
        The number of fixed draws to use for the optimisation. More
        draws will result in more accurate estimates, but also
        increase inference time. Usually, the default of 30 is a good
        tradeoff between speed and accuracy.

    random_seed : int
        The random seed to use for the fixed draws. Running the
        optimisation twice with the same seed should arrive at the same
        result.

    n_draws : int
        The number of draws to return from the variational approximation.

    keep_untransformed : bool
        Whether or not to keep the unconstrained variables (such as
        logs of positive-constrained parameters) in the output.

    optimizer_method : str
        Which optimization method to use. The function calls
        ``scipy.optimize.minimize``, so any of the methods there can
        be used. The default is trust-ncg, which uses second-order
        information and is generally very reliable. Other methods such
        as L-BFGS-B might be faster but potentially more brittle and
        may not converge exactly to the optimum.

    use_grad : bool
        If True, pass the gradient function to
        ``scipy.optimize.minimize`` (where it is referred to as ``jac``).

    use_hessp : bool
        If True, pass the Hessian-vector product to
        ``scipy.optimize.minimize``.

    use_hess : bool
        If True, pass the Hessian to ``scipy.optimize.minimize``. Note
        that this is generally not recommended, since its computation
        can be slow and memory-intensive if there are many parameters.

    minimize_kwargs :
        Additional keyword arguments to pass to the
        ``scipy.optimize.minimize`` function. See the documentation of
        that function for details.

    Returns
    -------
    :class:`~arviz.InferenceData`
        The inference data containing the results of the DADVI algorithm.

    References
    ----------
    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
    Variational Inference with a Deterministic Objective: Faster, More
    Accurate, and Even More Black Box. Journal of Machine Learning
    Research, 25(18), 1–39.
    """

    model = pymc.modelcontext(model) if model is None else model

    initial_point_dict = model.initial_point()
    n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]

    var_params, objective = create_dadvi_graph(
        model,
        n_fixed_draws=n_fixed_draws,
        random_seed=random_seed,
        n_params=n_params,
    )

    f_fused, f_hessp = _compile_functions_for_scipy_optimize(
        objective,
        [var_params],
        compute_grad=use_grad,
        compute_hessp=use_hessp,
        compute_hess=use_hess,
    )

    derivative_kwargs = {}

    if use_grad:
        derivative_kwargs["jac"] = True
    if use_hessp:
        derivative_kwargs["hessp"] = f_hessp
    if use_hess:
        derivative_kwargs["hess"] = True

    result = minimize(
        f_fused,
        np.zeros(2 * n_params),
        method=optimizer_method,
        **derivative_kwargs,
        **minimize_kwargs,
    )

    opt_var_params = result.x
    opt_means, opt_log_sds = np.split(opt_var_params, 2)

    # Make the draws:
    generator = np.random.default_rng(seed=random_seed)
    draws_raw = generator.standard_normal(size=(n_draws, n_params))

    draws = opt_means + draws_raw * np.exp(opt_log_sds)
    draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)

    transformed_draws = transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)

    return transformed_draws
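
As a usage illustration (not part of this commit), here is a minimal sketch that fits a toy normal model with fit_dadvi; the model, data, and settings below are assumptions made for the example:

import numpy as np
import pymc as pm

from pymc_extras.inference.dadvi.dadvi import fit_dadvi

# Toy data: 200 noisy observations around 1.0
observed = np.random.default_rng(42).normal(loc=1.0, scale=2.0, size=200)

with pm.Model() as toy_model:
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    sigma = pm.HalfNormal("sigma", sigma=5.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

# 30 fixed draws (the default) for the objective, 1000 draws returned
idata = fit_dadvi(model=toy_model, n_fixed_draws=30, n_draws=1000, random_seed=1)
print(idata)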


def create_dadvi_graph(
    model: Model,
    n_params: int,
    n_fixed_draws: int = 30,
    random_seed: RandomSeed = None,
) -> tuple[TensorVariable, TensorVariable]:
    """
    Sets up the DADVI graph in pytensor and returns it.

    Parameters
    ----------
    model : pm.Model
        The PyMC model to be fit.

    n_params : int
        The total number of parameters in the model.

    n_fixed_draws : int
        The number of fixed draws to use.

    random_seed : int
        The random seed to use for the fixed draws.

    Returns
    -------
    tuple[TensorVariable, TensorVariable]
        A tuple whose first element contains the variational parameters,
        and whose second contains the DADVI objective.
    """

    # Make the fixed draws
    generator = np.random.default_rng(seed=random_seed)
    draws = generator.standard_normal(size=(n_fixed_draws, n_params))

    inputs = model.continuous_value_vars + model.discrete_value_vars
    initial_point_dict = model.initial_point()
    logp = model.logp()

    # Graph in terms of a flat input
    [logp], flat_input = join_nonshared_inputs(
        point=initial_point_dict, outputs=[logp], inputs=inputs
    )

    var_params = pt.vector(name="eta", shape=(2 * n_params,))

    means, log_sds = pt.split(var_params, axis=0, splits_size=[n_params, n_params], n_splits=2)

    draw_matrix = pt.constant(draws)
    samples = means + pt.exp(log_sds) * draw_matrix

    logp_vectorized_draws = pytensor.graph.vectorize_graph(logp, replace={flat_input: samples})

    mean_log_density = pt.mean(logp_vectorized_draws)
    entropy = pt.sum(log_sds)

    objective = -mean_log_density - entropy

    return var_params, objective
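
For reference, the graph assembled above is the fixed-draw DADVI objective from the paper: with $K$ fixed standard-normal draws $z_k$ and variational parameters $\eta = (m, s)$,

$$\mathcal{L}(m, s) = -\frac{1}{K}\sum_{k=1}^{K} \log p\left(m + e^{s} \odot z_k\right) - \sum_{i} s_i,$$

i.e. the negative of a draw-averaged ELBO estimate, where $\sum_i s_i$ is the Gaussian entropy up to an additive constant that does not affect the optimum. Because the draws $z_k$ are fixed, the objective is deterministic, which is what allows it to be handed to a standard second-order optimizer such as trust-ncg rather than stochastic gradient descent.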


def transform_draws(
    unstacked_draws: xarray.Dataset,
    model: Model,
    keep_untransformed: bool = False,
):
    """
    Transforms the unconstrained draws back into the constrained space.

    Parameters
    ----------
    unstacked_draws : xarray.Dataset
        The draws to constrain back into the original space.

    model : Model
        The PyMC model the variables were derived from.

    keep_untransformed : bool
        Whether or not to keep the unconstrained variables in the output.

    Returns
    -------
    :class:`~arviz.InferenceData`
        Draws from the original constrained parameters.
    """

    filtered_var_names = model.unobserved_value_vars
    vars_to_sample = list(
        get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
    )
    fn = pytensor.function(model.value_vars, vars_to_sample)
    point_func = PointFunc(fn)

    coords, dims = coords_and_dims_for_inferencedata(model)

    transformed_result = apply_function_over_dataset(
        point_func,
        unstacked_draws,
        output_var_names=[x.name for x in vars_to_sample],
        coords=coords,
        dims=dims,
    )

    return transformed_result
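
To make the objective concrete, here is a small NumPy sketch (illustration only, not part of this commit) that evaluates the same quantity for a standard-normal target, for which log p(z) = -0.5 z^2 - 0.5 log(2*pi):

import numpy as np

rng = np.random.default_rng(0)
n_params, n_fixed_draws = 1, 30
fixed_draws = rng.standard_normal(size=(n_fixed_draws, n_params))

def dadvi_objective(eta):
    # eta stacks the means and log standard deviations, as in the graph above
    means, log_sds = np.split(eta, 2)
    samples = means + np.exp(log_sds) * fixed_draws
    log_densities = -0.5 * (samples**2).sum(axis=1) - 0.5 * n_params * np.log(2 * np.pi)
    return -log_densities.mean() - log_sds.sum()

# At eta = 0 the approximation already matches the standard-normal target,
# so the fixed-draw optimum lies near (though not exactly at) zero.
print(dadvi_objective(np.zeros(2 * n_params)))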

pymc_extras/inference/fit.py

Lines changed: 5 additions & 0 deletions
@@ -40,3 +40,8 @@ def fit(method: str, **kwargs) -> az.InferenceData:
         from pymc_extras.inference import fit_laplace

         return fit_laplace(**kwargs)
+
+    if method == "dadvi":
+        from pymc_extras.inference import fit_dadvi
+
+        return fit_dadvi(**kwargs)
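
With this dispatch in place, DADVI is also reachable through the generic fit entry point. A minimal sketch (any PyMC model context works, e.g. the toy model from the example above):

from pymc_extras.inference import fit

with toy_model:
    idata = fit(method="dadvi", n_draws=500)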
