Add deterministic advi #564
@@ -0,0 +1,237 @@
import arviz as az
import numpy as np
import pymc
import pytensor
import pytensor.tensor as pt
import xarray

from better_optimize.constants import minimize_method
from pymc import DictToArrayBijection, Model, join_nonshared_inputs
from pymc.backends.arviz import (
    PointFunc,
    apply_function_over_dataset,
    coords_and_dims_for_inferencedata,
)
from pymc.util import RandomSeed, get_default_varnames
from pytensor.tensor.variable import TensorVariable
from scipy.optimize import minimize

from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
from pymc_extras.inference.laplace_approx.scipy_interface import (
    _compile_functions_for_scipy_optimize,
)


def fit_dadvi(
    model: Model | None = None,
    n_fixed_draws: int = 30,
    random_seed: RandomSeed = None,
    n_draws: int = 1000,
    keep_untransformed: bool = False,
    opt_method: minimize_method = "trust-ncg",
    **minimize_kwargs,
) -> az.InferenceData:
    """
    Performs inference using deterministic ADVI (automatic differentiation
    variational inference), DADVI for short.

    For full details see the paper cited in the references:
    https://www.jmlr.org/papers/v25/23-1015.html

    Parameters
    ----------
    model : pm.Model
        The PyMC model to be fit. If None, the current model context is used.

    n_fixed_draws : int
        The number of fixed draws to use for the optimisation. More draws
        result in more accurate estimates but also increase inference time.
        Usually, the default of 30 is a good tradeoff between speed and
        accuracy.

    random_seed : int
        The random seed to use for the fixed draws. Running the optimisation
        twice with the same seed should arrive at the same result.

    n_draws : int
        The number of draws to return from the variational approximation.

    keep_untransformed : bool
        Whether or not to keep the unconstrained variables (such as logs of
        positive-constrained parameters) in the output.

    opt_method : str
        Which optimization method to use. The function calls
        ``scipy.optimize.minimize``, so any of the methods there can be
        used. The default is trust-ncg, which uses second-order information
        and is generally very reliable. Other methods such as L-BFGS-B might
        be faster but potentially more brittle and may not converge exactly
        to the optimum.

    minimize_kwargs :
        Additional keyword arguments to pass to the
        ``scipy.optimize.minimize`` function. See the documentation of that
        function for details.

    Returns
    -------
    :class:`~arviz.InferenceData`
        The inference data containing the results of the DADVI algorithm.

    References
    ----------
    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational
    Inference with a Deterministic Objective: Faster, More Accurate, and
    Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
    """

    model = pymc.modelcontext(model) if model is None else model

    initial_point_dict = model.initial_point()
    n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]

    var_params, objective = create_dadvi_graph(
        model,
        n_fixed_draws=n_fixed_draws,
        random_seed=random_seed,
        n_params=n_params,
    )

    f_fused, f_hessp = _compile_functions_for_scipy_optimize(
        objective,
        [var_params],
        compute_grad=True,
        compute_hessp=True,
        compute_hess=False,
    )

    result = minimize(
        f_fused,
        np.zeros(2 * n_params),
        method=opt_method,
        jac=True,
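The hunk is cut off here by an inline review thread, and the body of create_dadvi_graph is not shown above. For orientation only: the shape of the objective it returns follows from the paper and from the np.zeros(2 * n_params) starting point, namely one mean and one log standard deviation per model parameter, optimised against a fixed-draw estimate of the negative ELBO. A rough sketch under those assumptions, with the hypothetical logp standing in for the model's joint log density:

import numpy as np
import pytensor.tensor as pt

def dadvi_objective_sketch(logp, n_params, n_fixed_draws, random_seed):
    # Sample the standard-normal noise once; reusing these fixed draws at
    # every optimisation step is what makes the objective deterministic.
    rng = np.random.default_rng(random_seed)
    Z = rng.standard_normal((n_fixed_draws, n_params))

    # One mean and one log standard deviation per model parameter.
    var_params = pt.vector("eta", shape=(2 * n_params,))
    means, log_sds = var_params[:n_params], var_params[n_params:]

    # Reparameterised draws from the mean-field Gaussian approximation.
    draws = means + pt.exp(log_sds) * Z

    # Fixed-draw estimate of the negative ELBO: average negative joint log
    # density minus the Gaussian entropy (up to an additive constant).
    neg_logp = -pt.mean(pt.stack([logp(draws[i]) for i in range(n_fixed_draws)]))
    return var_params, neg_logp - pt.sum(log_sds)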
A reviewer left an inline suggestion on the jac=True, line:

Suggested change:
-        jac=True,
+        jac=minimize_kwargs.get('use_jac', True),
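Context for the suggestion: with jac=True, scipy.optimize.minimize expects the callable to return the pair (value, gradient), which is what the fused function compiled above provides, and hessp supplies Hessian-vector products to second-order methods such as trust-ncg. A self-contained toy illustration of that contract (not the PR's code):

import numpy as np
from scipy.optimize import minimize

# Quadratic objective returning (value, gradient) in one call, mirroring the
# fused value-and-gradient function DADVI compiles.
def f_fused(x):
    return 0.5 * np.dot(x, x), x

# Hessian-vector product; the Hessian of this quadratic is the identity.
def f_hessp(x, p):
    return p

result = minimize(f_fused, np.ones(4), method="trust-ncg", jac=True, hessp=f_hessp)
print(result.x)  # converges to the origin

Some optimisers never consume the gradient, so computing it on every call is wasted work; that is one motivation for an opt-out switch like the suggested use_jac.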
Review discussion on the parameter name:

We should agree on a unified name for this, because `find_MAP` uses `method`, `fit_laplace` uses `optimizer_method`, and now this is `opt_method`. My first choice is just `method`, but it clashes with the `pmx.fit` API (which I don't really like anyway). Thoughts?
Thanks @jessegrabowski, I agree. I first matched the syntax in `find_map` but then noticed the clash with the `fit` method. I'm also not a huge fan of the shared API, to be honest. On the other hand, I personally think `method` is maybe a bit vague anyhow. I'd be OK to go with `optimizer_method`; then at least it's consistent with `fit_laplace`. What do you think?
Sounds good to me.
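To round off the thread, a hedged end-to-end sketch of how the new function would be called. The import path is hypothetical, since the diff does not show the new file's location or the package's public re-exports:

import numpy as np
import pymc as pm

# Hypothetical import path; adjust to wherever the PR places fit_dadvi.
from pymc_extras.inference.laplace_approx.dadvi import fit_dadvi

rng = np.random.default_rng(0)
y = rng.normal(loc=1.0, scale=2.0, size=200)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("obs", mu, sigma, observed=y)

    # The same random_seed fixes the same draws, so repeated runs reach the
    # same deterministic optimum.
    idata = fit_dadvi(n_fixed_draws=30, n_draws=1000, random_seed=1234)

print(idata.posterior["mu"].mean(), idata.posterior["sigma"].mean())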