Add deterministic advi #564
Merged: zaxtax merged 30 commits into pymc-devs:main from martiningram:add_basic_deterministic_advi on Sep 3, 2025
Changes from 11 commits
Commits (30 total)
488bd9c  Add first version of deterministic ADVI
f46f1cd  Update API
894f62b  Add a notebook example
a1afaf6  Merge branch 'main' into add_basic_deterministic_advi
637fc3b  Add to API and add a docstring
3e397f7  Change import in notebook
d954ec7  Add jax to dependencies
aad9f21  Add pytensor version
ef3d86b  Fix handling of pymc model
6bf92ef  Add (probably suboptimal) handling of the two backends
32aff46  Add transformation
138f8c2  Follow Ricardo's advice to simplify the transformation step
7073a7d  Fix naming bug
609aef7  Document and clean up
b611d51  Merge branch 'main' into add_basic_deterministic_advi
f17a090  Fix example
9ab2e1e  Update pymc_extras/inference/deterministic_advi/dadvi.py (martiningram)
a8a53f3  Respond to comments
bdee446  Fix with pre commit checks
3fcafb6  Update pymc_extras/inference/deterministic_advi/dadvi.py (martiningram)
ad46b07  Implement suggestions
6cd0184  Rename parameter because it's duplicated otherwise
d648105  Rename to be consistent in use of dadvi
9d18f80  Rename to `optimizer_method` and drop jac=True
9f86d4f  Add jac=True back in since trust-ncg complained
3b090ca  Make hessp and jac optional
93cd831  Harmonize naming with existing code
7b84872  Fix example
7cd407e  Switch to `better_optimize`
cb070aa  Replace with pt.split
@@ -0,0 +1,163 @@
from typing import Callable, Dict

import numpy as np
import pymc
import arviz as az
from jax import vmap

from pymc_extras.inference.deterministic_advi.jax import build_dadvi_funs
from pymc_extras.inference.deterministic_advi.pymc_to_jax import (
    get_jax_functions_from_pymc,
    transform_dadvi_draws,
)
from pymc_extras.inference.deterministic_advi.core import (
    find_dadvi_optimum,
    get_dadvi_draws,
    DADVIFuns,
)
from pymc_extras.inference.deterministic_advi.utils import opt_callback_fun


class DADVIResult:
    def __init__(
        self,
        fixed_draws: np.ndarray,
        var_params: np.ndarray,
        unflattening_fun: Callable[[np.ndarray], Dict[str, np.ndarray]],
        dadvi_funs: DADVIFuns,
        pymc_model: pymc.Model,  # TODO Check the type here
    ):
        self.fixed_draws = fixed_draws
        self.var_params = var_params
        self.unflattening_fun = unflattening_fun
        self.dadvi_funs = dadvi_funs
        self.n_params = self.fixed_draws.shape[1]
        self.pymc_model = pymc_model

    def get_posterior_means(self) -> Dict[str, np.ndarray]:
        """
        Returns a dictionary with posterior means for all parameters.
        """
        means = np.split(self.var_params, 2)[0]
        return self.unflattening_fun(means)

    def get_posterior_standard_deviations_mean_field(self) -> Dict[str, np.ndarray]:
        """
        Returns a dictionary with posterior standard deviations (not LRVB-corrected, but mean field).
        """
        log_sds = np.split(self.var_params, 2)[1]
        sds = np.exp(log_sds)
        return self.unflattening_fun(sds)

    def get_posterior_draws_mean_field(
        self,
        n_draws: int = 1000,
        seed: int = 2,
        transform_draws: bool = True,
    ) -> Dict[str, np.ndarray]:
        """
        Returns a dictionary with draws from the posterior.
        """
        np.random.seed(seed)
        z = np.random.randn(n_draws, self.n_params)
        dadvi_draws_flat = get_dadvi_draws(self.var_params, z)

        if transform_draws:
            dadvi_draws = transform_dadvi_draws(
                self.pymc_model,
                dadvi_draws_flat,
                self.unflattening_fun,
                add_chain_dim=True,
            )
        else:
            dadvi_draws = vmap(self.unflattening_fun)(dadvi_draws_flat)

        return dadvi_draws

    def compute_function_on_mean_field_draws(
        self,
        function_to_run: Callable[[Dict], np.ndarray],
        n_draws: int = 1000,
        seed: int = 2,
    ):
        dadvi_dict = self.get_posterior_draws_mean_field(n_draws, seed)
        return vmap(function_to_run)(dadvi_dict)


def fit_deterministic_advi(model=None, num_fixed_draws=30, seed=2):
    """
    Does inference using deterministic ADVI (automatic differentiation
    variational inference).

    For full details see the paper cited in the references:
    https://www.jmlr.org/papers/v25/23-1015.html

    Parameters
    ----------
    model : pm.Model
        The PyMC model to be fit. If None, the current model context is used.

    num_fixed_draws : int
        The number of fixed draws to use for the optimisation. More
        draws will result in more accurate estimates, but also
        increase inference time. Usually, the default of 30 is a good
        tradeoff between speed and accuracy.

    seed : int
        The random seed to use for the fixed draws. Running the optimisation
        twice with the same seed should arrive at the same result.

    Returns
    -------
    :class:`~arviz.InferenceData`
        The inference data containing the results of the DADVI algorithm.

    References
    ----------
    Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational
    Inference with a Deterministic Objective: Faster, More Accurate, and Even
    More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
    """
    model = pymc.modelcontext(model) if model is None else model

    np.random.seed(seed)

    jax_funs = get_jax_functions_from_pymc(model)
    dadvi_funs = build_dadvi_funs(jax_funs["log_posterior_fun"])

    opt_callback_fun.opt_sequence = []

    init_means = np.zeros(jax_funs["n_params"])
    init_log_vars = np.zeros(jax_funs["n_params"]) - 3
    init_var_params = np.concatenate([init_means, init_log_vars])
    zs = np.random.randn(num_fixed_draws, jax_funs["n_params"])
    opt = find_dadvi_optimum(
        init_params=init_var_params,
        zs=zs,
        dadvi_funs=dadvi_funs,
        verbose=True,
        callback_fun=opt_callback_fun,
    )

    dadvi_result = DADVIResult(
        fixed_draws=zs,
        var_params=opt["opt_result"].x,
        unflattening_fun=jax_funs["unflatten_fun"],
        dadvi_funs=dadvi_funs,
        pymc_model=model,
    )

    # Get draws and turn into the arviz format expected
    draws = dadvi_result.get_posterior_draws_mean_field(transform_draws=True)
    az_draws = az.convert_to_inference_data(draws)

    return az_draws
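
To make the new entry point concrete, here is a minimal usage sketch (not part of this diff). The import path simply mirrors the module layout above, and the toy model is illustrative only:

import numpy as np
import pymc as pm

from pymc_extras.inference.deterministic_advi.dadvi import fit_deterministic_advi

rng = np.random.default_rng(0)
observed = rng.normal(loc=1.0, scale=2.0, size=100)

with pm.Model() as toy_model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=observed)

    # fit_deterministic_advi picks up the model from the context and
    # returns an arviz.InferenceData with mean-field posterior draws.
    idata = fit_deterministic_advi(num_fixed_draws=30, seed=2)

print(idata.posterior["mu"].mean())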
@@ -0,0 +1,123 @@
"""
Core computations for DADVI.
"""

from typing import NamedTuple, Callable, Optional, Dict

from scipy.sparse.linalg import LinearOperator

import numpy as np
from pymc_extras.inference.deterministic_advi.optimization import optimize_with_hvp


class DADVIFuns(NamedTuple):
    """
    This NamedTuple holds the functions required to run DADVI.

    Args:
        kl_est_and_grad_fun: Function of eta [variational parameters] and zs [draws].
            zs should have shape [M, D], where M is the number of fixed draws and D is
            the problem dimension. Returns a tuple whose first element is the estimate
            of the KL divergence, and the second is its gradient w.r.t. eta.
        kl_est_hvp_fun: Function of eta, zs, and b, a vector to compute the hvp
            with. This should return a vector -- the result of the hvp with b.
    """

    kl_est_and_grad_fun: Callable[[np.ndarray, np.ndarray], np.ndarray]
    kl_est_hvp_fun: Optional[Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray]]


def find_dadvi_optimum(
    init_params: np.ndarray,
    zs: np.ndarray,
    dadvi_funs: DADVIFuns,
    opt_method: str = "trust-ncg",
    callback_fun: Optional[Callable] = None,
    verbose: bool = False,
) -> Dict:
    """
    Optimises the DADVI objective.

    Args:
        init_params: The initial variational parameters to use. This should be a
            vector of length 2D, where D is the problem dimension. The first D
            entries specify the variational means, while the last D specify the log
            standard deviations.
        zs: The fixed draws to use in the optimisation. They must be of shape
            [M, D], where D is the problem dimension and M is the number of fixed
            draws.
        dadvi_funs: The objective to optimise. See the definition of DADVIFuns for
            more information. The kl_est_and_grad_fun is required for optimisation;
            the kl_est_hvp_fun is needed only for some optimisers.
        opt_method: The optimisation method to use. This must be one of the methods
            listed for scipy.optimize.minimize
            [https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html].
            Defaults to trust-ncg, which requires the hvp to be available. For
            gradient-only optimisation, L-BFGS-B generally works well.
        callback_fun: If provided, this callback function is passed to
            scipy.optimize.minimize. See that function's documentation for more.
        verbose: If True, prints the progress of the optimisation by showing the
            value and gradient norm at each iteration of the optimizer.

    Returns:
        A dictionary with entries "opt_result", containing the results of running
        scipy.optimize.minimize, and "evaluation_count", containing the number of
        times the hvp and gradient functions were called.
    """

    val_and_grad_fun = lambda var_params: dadvi_funs.kl_est_and_grad_fun(var_params, zs)
    hvp_fun = (
        None
        if dadvi_funs.kl_est_hvp_fun is None
        else lambda var_params, b: dadvi_funs.kl_est_hvp_fun(var_params, zs, b)
    )

    opt_result, eval_count = optimize_with_hvp(
        val_and_grad_fun,
        hvp_fun,
        init_params,
        opt_method=opt_method,
        callback_fun=callback_fun,
        verbose=verbose,
    )

    to_return = {
        "opt_result": opt_result,
        "evaluation_count": eval_count,
    }

    # TODO: Here I originally had a Newton step check to assess
    # convergence. Could add this back in.

    return to_return


def get_dadvi_draws(var_params: np.ndarray, zs: np.ndarray) -> np.ndarray:
    """
    Computes draws from the mean-field variational approximation given
    variational parameters and a matrix of fixed draws.

    Args:
        var_params: A vector of shape 2D, the first D entries specifying the
            means for the D model parameters, and the last D the log standard
            deviations.
        zs: A matrix of shape [N, D], containing the draws to use to sample the
            variational approximation.

    Returns:
        A matrix of shape [N, D] containing N draws from the variational
        approximation.
    """

    # TODO: Could use JAX here
    means, log_sds = np.split(var_params, 2)
    sds = np.exp(log_sds)

    draws = means.reshape(1, -1) + zs * sds.reshape(1, -1)

    return draws


# TODO -- I think the functions above cover the basic functionality of
# fixed-draw ADVI. But I have not yet included the LRVB portion of the
# code, in the interest of keeping it simple. Can add later.
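
As a quick sanity check on the shape conventions in get_dadvi_draws, here is a small NumPy sketch (illustrative only, with D = 3 parameters and N = 5 fixed draws):

import numpy as np

D, N = 3, 5
means = np.array([0.0, 1.0, -1.0])
log_sds = np.log(np.array([0.5, 1.0, 2.0]))
var_params = np.concatenate([means, log_sds])   # length 2D: [means, log_sds]

zs = np.random.default_rng(0).standard_normal((N, D))

# Same reparameterisation as get_dadvi_draws: mean + z * sd, broadcast over rows.
draws = means.reshape(1, -1) + zs * np.exp(log_sds).reshape(1, -1)
print(draws.shape)  # (5, 3): one row per draw, one column per model parameter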
@@ -0,0 +1,63 @@
from typing import Callable, Dict, Tuple
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap, value_and_grad, jvp, grad
from functools import partial

from pymc_extras.inference.deterministic_advi.core import DADVIFuns
from pymc_extras.inference.deterministic_advi.optimization import count_decorator


@partial(jit, static_argnums=0)
def hvp(f, primals, tangents):
    # Taken (and slightly modified) from:
    # https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
    return jvp(grad(f), (primals,), (tangents,))[1]


@jit
def _make_draws(z, mean, log_sd):
    draw = z * jnp.exp(log_sd) + mean
    return draw


@jit
def _calculate_entropy(log_sds):
    return jnp.sum(log_sds)


def build_dadvi_funs(log_posterior_fn: Callable[[jnp.ndarray], float]) -> DADVIFuns:
    """
    Builds the DADVIFuns from a log posterior density function written in JAX.
    """

    def single_log_posterior_fun(cur_z, var_params):
        means, log_sds = jnp.split(var_params, 2)
        cur_theta = _make_draws(cur_z, means, log_sds)
        return log_posterior_fn(cur_theta)

    def log_posterior_expectation(zs, var_params):
        single_curried = partial(single_log_posterior_fun, var_params=var_params)
        log_posts = vmap(single_curried)(zs)
        return jnp.mean(log_posts)

    def full_kl_est(var_params, zs):
        _, log_sds = jnp.split(var_params, 2)
        log_posterior = log_posterior_expectation(zs, var_params)
        entropy = _calculate_entropy(log_sds)
        return -log_posterior - entropy

    @jit
    def kl_est_hvp_fun(var_params, zs, b):
        rel_kl_est = partial(full_kl_est, zs=zs)
        rel_hvp = lambda x, y: hvp(rel_kl_est, x, y)
        return rel_hvp(var_params, b)

    kl_est_and_grad_fun = jit(value_and_grad(full_kl_est))

    return DADVIFuns(
        kl_est_and_grad_fun=kl_est_and_grad_fun, kl_est_hvp_fun=kl_est_hvp_fun
    )
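
Putting the pieces together: full_kl_est is a fixed-draw estimate of the negative ELBO, i.e. minus the average log posterior over the M reparameterised draws minus the entropy term sum(log_sds). A minimal end-to-end sketch on a standard-normal target, assuming the modules in this diff are importable as laid out (illustrative only, not part of the PR):

import numpy as np
import jax.numpy as jnp

from pymc_extras.inference.deterministic_advi.jax import build_dadvi_funs
from pymc_extras.inference.deterministic_advi.core import find_dadvi_optimum, get_dadvi_draws

def log_posterior(theta):
    # Standard-normal log density in D = 2 dimensions (up to an additive constant).
    return -0.5 * jnp.sum(theta**2)

D, M = 2, 30
dadvi_funs = build_dadvi_funs(log_posterior)

np.random.seed(2)
zs = np.random.randn(M, D)                                        # fixed draws, shape [M, D]
init_var_params = np.concatenate([np.zeros(D), np.zeros(D) - 3])  # [means, log_sds]

opt = find_dadvi_optimum(init_params=init_var_params, zs=zs, dadvi_funs=dadvi_funs)

posterior_draws = get_dadvi_draws(opt["opt_result"].x, np.random.randn(1000, D))
print(posterior_draws.mean(axis=0), posterior_draws.std(axis=0))  # should be near 0 and 1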