Add deterministic ADVI #564

Open · wants to merge 11 commits into base: main
1,307 changes: 1,307 additions & 0 deletions notebooks/deterministic_advi_example.ipynb

Large diffs are not rendered by default.

13 changes: 12 additions & 1 deletion pymc_extras/inference/__init__.py
@@ -16,5 +16,16 @@
from pymc_extras.inference.laplace_approx.find_map import find_MAP
from pymc_extras.inference.laplace_approx.laplace import fit_laplace
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
from pymc_extras.inference.deterministic_advi.api import (
fit_deterministic_advi as fit_deterministic_advi_jax,
)
from pymc_extras.inference.deterministic_advi.pytensor import fit_deterministic_advi

__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
__all__ = [
"fit",
"fit_pathfinder",
"fit_laplace",
"find_MAP",
"fit_deterministic_advi",
"fit_deterministic_advi_jax",
]
Empty file.
163 changes: 163 additions & 0 deletions pymc_extras/inference/deterministic_advi/api.py
@@ -0,0 +1,163 @@
from typing import Callable, Dict

import numpy as np
import pymc
import arviz as az
from jax import vmap

from pymc_extras.inference.deterministic_advi.jax import build_dadvi_funs
from pymc_extras.inference.deterministic_advi.pymc_to_jax import (
get_jax_functions_from_pymc,
transform_dadvi_draws,
)
from pymc_extras.inference.deterministic_advi.core import (
find_dadvi_optimum,
get_dadvi_draws,
DADVIFuns,
)
from pymc_extras.inference.deterministic_advi.utils import opt_callback_fun


class DADVIResult:
def __init__(
self,
fixed_draws: np.ndarray,
var_params: np.ndarray,
unflattening_fun: Callable[[np.ndarray], Dict[str, np.ndarray]],
dadvi_funs: DADVIFuns,
pymc_model: pymc.Model, # TODO Check the type here
):

self.fixed_draws = fixed_draws
self.var_params = var_params
self.unflattening_fun = unflattening_fun
self.dadvi_funs = dadvi_funs
self.n_params = self.fixed_draws.shape[1]
self.pymc_model = pymc_model

def get_posterior_means(self) -> Dict[str, np.ndarray]:
"""
Returns a dictionary with posterior means for all parameters.
"""

means = np.split(self.var_params, 2)[0]
return self.unflattening_fun(means)

def get_posterior_standard_deviations_mean_field(self) -> Dict[str, np.ndarray]:
"""
Returns a dictionary with posterior standard deviations (not LRVB-corrected, but mean field).
"""

log_sds = np.split(self.var_params, 2)[1]
sds = np.exp(log_sds)
return self.unflattening_fun(sds)

def get_posterior_draws_mean_field(
self,
n_draws: int = 1000,
seed: int = 2,
transform_draws: bool = True,
) -> Dict[str, np.ndarray]:
"""
Returns a dictionary with draws from the posterior.
"""

np.random.seed(seed)
z = np.random.randn(n_draws, self.n_params)
dadvi_draws_flat = get_dadvi_draws(self.var_params, z)

if transform_draws:

dadvi_draws = transform_dadvi_draws(
self.pymc_model,
dadvi_draws_flat,
self.unflattening_fun,
add_chain_dim=True,
)

else:

dadvi_draws = vmap(self.unflattening_fun)(dadvi_draws_flat)

return dadvi_draws

def compute_function_on_mean_field_draws(
self,
function_to_run: Callable[[Dict], np.ndarray],
n_draws: int = 1000,
seed: int = 2,
):
dadvi_dict = self.get_posterior_draws_mean_field(n_draws, seed)

return vmap(function_to_run)(dadvi_dict)


def fit_deterministic_advi(model=None, num_fixed_draws=30, seed=2):
"""
Performs inference using deterministic ADVI (automatic differentiation
variational inference).

For full details see the paper cited in the references:
https://www.jmlr.org/papers/v25/23-1015.html

Parameters
----------
model : pm.Model
The PyMC model to be fit. If None, the current model context is used.

num_fixed_draws : int
The number of fixed draws to use for the optimisation. More
draws will result in more accurate estimates but will also
increase inference time. Usually, the default of 30 is a good
tradeoff between speed and accuracy.

seed : int
The random seed to use for the fixed draws. Running the optimisation
twice with the same seed should arrive at the same result.

Returns
-------
:class:`~arviz.InferenceData`
The inference data containing the results of the DADVI algorithm.

References
----------
Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective: Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.


"""

model = pymc.modelcontext(model) if model is None else model

np.random.seed(seed)

jax_funs = get_jax_functions_from_pymc(model)
dadvi_funs = build_dadvi_funs(jax_funs["log_posterior_fun"])

opt_callback_fun.opt_sequence = []

init_means = np.zeros(jax_funs["n_params"])
init_log_vars = np.zeros(jax_funs["n_params"]) - 3
init_var_params = np.concatenate([init_means, init_log_vars])
zs = np.random.randn(num_fixed_draws, jax_funs["n_params"])
opt = find_dadvi_optimum(
init_params=init_var_params,
zs=zs,
dadvi_funs=dadvi_funs,
verbose=True,
callback_fun=opt_callback_fun,
)

dadvi_result = DADVIResult(
fixed_draws=zs,
var_params=opt["opt_result"].x,
unflattening_fun=jax_funs["unflatten_fun"],
dadvi_funs=dadvi_funs,
pymc_model=model,
)

# Get draws and convert them into the expected arviz InferenceData format
draws = dadvi_result.get_posterior_draws_mean_field(transform_draws=True)
az_draws = az.convert_to_inference_data(draws)

return az_draws
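
As a quick illustration of the API documented above, here is a minimal usage sketch. It is not part of this diff: the toy data and model are hypothetical, the entry point is assumed to be the fit_deterministic_advi_jax alias exported in pymc_extras/inference/__init__.py, and the posterior variable names are assumed to match the model's.

import numpy as np
import pymc as pm

from pymc_extras.inference import fit_deterministic_advi_jax

# Hypothetical toy data: draws from a normal with unknown mean and scale.
rng = np.random.default_rng(0)
y_obs = rng.normal(loc=1.0, scale=2.0, size=200)

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    sigma = pm.HalfNormal("sigma", sigma=5.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)

# Returns an arviz.InferenceData holding mean-field posterior draws.
idata = fit_deterministic_advi_jax(model=model, num_fixed_draws=30, seed=2)
print(idata.posterior["mu"].mean(), idata.posterior["sigma"].mean())

Per the docstring above, increasing num_fixed_draws tightens the fixed-draw approximation at the cost of a more expensive objective.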
123 changes: 123 additions & 0 deletions pymc_extras/inference/deterministic_advi/core.py
@@ -0,0 +1,123 @@
"""
Core computations for DADVI.
"""

from typing import NamedTuple, Callable, Optional, Dict, Tuple

from scipy.sparse.linalg import LinearOperator

import numpy as np
from pymc_extras.inference.deterministic_advi.optimization import optimize_with_hvp


class DADVIFuns(NamedTuple):
"""
This NamedTuple holds the functions required to run DADVI.

Args:
kl_est_and_grad_fun: Function of eta [the variational parameters] and zs
[the fixed draws]. zs should have shape [M, D], where M is the number of
fixed draws and D is the problem dimension. Returns a tuple whose first
element is the estimate of the KL divergence and whose second element is
its gradient w.r.t. eta.
kl_est_hvp_fun: Function of eta, zs, and b, where b is the vector to
compute the Hessian-vector product (HVP) with. Returns a vector -- the
result of the HVP with b.
"""

kl_est_and_grad_fun: Callable[[np.ndarray, np.ndarray], Tuple[float, np.ndarray]]
kl_est_hvp_fun: Optional[Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray]]


def find_dadvi_optimum(
init_params: np.ndarray,
zs: np.ndarray,
dadvi_funs: DADVIFuns,
opt_method: str = "trust-ncg",
callback_fun: Optional[Callable] = None,
verbose: bool = False,
) -> Dict:
"""
Optimises the DADVI objective.

Args:
init_params: The initial variational parameters to use. This should be a
vector of length 2D, where D is the problem dimension. The first D
entries specify the variational means, while the last D specify the log
standard deviations.
zs: The fixed draws to use in the optimisation. They must be of shape
[M, D], where D is the problem dimension and M is the number of fixed
draws.
dadvi_funs: The objective to optimise. See the definition of DADVIFuns for
more information. The kl_est_and_grad_fun is required for optimisation;
the kl_est_hvp_fun is needed only for some optimisers.
opt_method: The optimisation method to use. This must be one of the methods
listed for scipy.optimize.minimize
[https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html].
Defaults to trust-ncg, which requires the hvp to be available. For
gradient-only optimisation, L-BFGS-B generally works well.
callback_fun: If provided, this callback function is passed to
scipy.optimize.minimize. See that function's documentation for more.
verbose: If True, prints the progress of the optimisation by showing the
value and gradient norm at each iteration of the optimizer.

Returns:
A dictionary with entries "opt_result", containing the results of running
scipy.optimize.minimize, and "evaluation_count", containing the number of
times the hvp and gradient functions were called.
"""

val_and_grad_fun = lambda var_params: dadvi_funs.kl_est_and_grad_fun(var_params, zs)
hvp_fun = (
None
if dadvi_funs.kl_est_hvp_fun is None
else lambda var_params, b: dadvi_funs.kl_est_hvp_fun(var_params, zs, b)
)

opt_result, eval_count = optimize_with_hvp(
val_and_grad_fun,
hvp_fun,
init_params,
opt_method=opt_method,
callback_fun=callback_fun,
verbose=verbose,
)

to_return = {
"opt_result": opt_result,
"evaluation_count": eval_count,
}

# TODO: Here I originally had a Newton step check to assess
# convergence. Could add this back in.

return to_return


def get_dadvi_draws(var_params: np.ndarray, zs: np.ndarray) -> np.ndarray:
"""
Computes draws from the mean-field variational approximation given
variational parameters and a matrix of fixed draws.

Args:
var_params: A vector of length 2D, the first D entries specifying the
means for the D model parameters, and the last D the log standard
deviations.
zs: A matrix of shape [N, D], containing the draws to use to sample the
variational approximation.

Returns:
A matrix of shape [N, D] containing N draws from the variational
approximation.
"""

# TODO: Could use JAX here
means, log_sds = np.split(var_params, 2)
sds = np.exp(log_sds)

draws = means.reshape(1, -1) + zs * sds.reshape(1, -1)

return draws


# TODO -- I think the functions above cover the basic functionality of
# fixed-draw ADVI. But I have not yet included the LRVB portion of the
# code, in the interest of keeping it simple. Can add later.
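
To make the DADVIFuns / find_dadvi_optimum interface above concrete, here is a minimal sketch that is not part of this diff. It builds the fixed-draw KL estimate by hand for a standard-normal target, where log p(theta) = -0.5 * ||theta||^2 up to a constant, supplies no HVP, and therefore relies on the gradient-only L-BFGS-B path described in the docstring.

import numpy as np

from pymc_extras.inference.deterministic_advi.core import (
    DADVIFuns,
    find_dadvi_optimum,
    get_dadvi_draws,
)

D = 2  # problem dimension


def kl_est_and_grad(var_params, zs):
    # Fixed-draw KL estimate: (1/M) * sum_i 0.5 * ||m + z_i * sds||^2 - sum(log_sds),
    # i.e. the negative expected log posterior minus the entropy term.
    means, log_sds = np.split(var_params, 2)
    sds = np.exp(log_sds)
    draws = means + zs * sds  # shape [M, D]
    kl = 0.5 * np.mean(np.sum(draws**2, axis=1)) - np.sum(log_sds)
    # Closed-form gradients of this estimate w.r.t. the means and log sds.
    grad_means = np.mean(draws, axis=0)
    grad_log_sds = np.mean(draws * zs, axis=0) * sds - 1.0
    return kl, np.concatenate([grad_means, grad_log_sds])


funs = DADVIFuns(kl_est_and_grad_fun=kl_est_and_grad, kl_est_hvp_fun=None)

zs = np.random.default_rng(0).standard_normal((30, D))
opt = find_dadvi_optimum(
    init_params=np.zeros(2 * D),
    zs=zs,
    dadvi_funs=funs,
    opt_method="L-BFGS-B",  # gradient-only, since no HVP is provided
)
posterior_draws = get_dadvi_draws(opt["opt_result"].x, zs)  # shape [30, D]

For this target the optimum should sit near zero means and near-zero log standard deviations, since the mean-field family contains the exact posterior and the fixed-draw estimate converges to the true KL as the number of draws grows.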
63 changes: 63 additions & 0 deletions pymc_extras/inference/deterministic_advi/jax.py
@@ -0,0 +1,63 @@
from typing import Callable, Dict, Tuple
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap, value_and_grad, jvp, grad
from functools import partial

from pymc_extras.inference.deterministic_advi.core import DADVIFuns
from pymc_extras.inference.deterministic_advi.optimization import count_decorator


@partial(jit, static_argnums=0)
def hvp(f, primals, tangents):
# Taken (and slightly modified) from:
# https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
return jvp(grad(f), (primals,), (tangents,))[1]


@jit
def _make_draws(z, mean, log_sd):

draw = z * jnp.exp(log_sd) + mean

return draw


@jit
def _calculate_entropy(log_sds):

return jnp.sum(log_sds)


def build_dadvi_funs(log_posterior_fn: Callable[[jnp.ndarray], float]) -> DADVIFuns:
"""
Builds the DADVIFuns from a log posterior density function written in JAX.
"""

def single_log_posterior_fun(cur_z, var_params):
means, log_sds = jnp.split(var_params, 2)
cur_theta = _make_draws(cur_z, means, log_sds)
return log_posterior_fn(cur_theta)

def log_posterior_expectation(zs, var_params):
single_curried = partial(single_log_posterior_fun, var_params=var_params)
log_posts = vmap(single_curried)(zs)
return jnp.mean(log_posts)

def full_kl_est(var_params, zs):
_, log_sds = jnp.split(var_params, 2)
log_posterior = log_posterior_expectation(zs, var_params)
entropy = _calculate_entropy(log_sds)
return -log_posterior - entropy

@jit
def kl_est_hvp_fun(var_params, zs, b):
rel_kl_est = partial(full_kl_est, zs=zs)
rel_hvp = lambda x, y: hvp(rel_kl_est, x, y)
return rel_hvp(var_params, b)

kl_est_and_grad_fun = jit(value_and_grad(full_kl_est))

return DADVIFuns(
kl_est_and_grad_fun=kl_est_and_grad_fun, kl_est_hvp_fun=kl_est_hvp_fun
)
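
As a small, self-contained check of the JAX wiring above (not part of this diff), a toy isotropic-Gaussian log posterior can be passed to build_dadvi_funs and the resulting functions evaluated directly:

import numpy as np
import jax.numpy as jnp

from pymc_extras.inference.deterministic_advi.jax import build_dadvi_funs

D = 3  # toy problem dimension


def log_posterior(theta):
    # Isotropic standard normal, up to an additive constant.
    return -0.5 * jnp.sum(theta**2)


dadvi_funs = build_dadvi_funs(log_posterior)

var_params = np.zeros(2 * D)  # first D entries: means, last D: log sds
zs = np.random.default_rng(0).standard_normal((30, D))

# KL estimate and its gradient w.r.t. var_params (shape [2 * D]).
kl, grad = dadvi_funs.kl_est_and_grad_fun(var_params, zs)

# Hessian-vector product of the KL estimate with a vector b (shape [2 * D]).
b = np.ones(2 * D)
hvp_with_b = dadvi_funs.kl_est_hvp_fun(var_params, zs, b)

These are the two callables that find_dadvi_optimum in core.py expects via DADVIFuns.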