Merged
Commits (30)

488bd9c  Add first version of deterministic ADVI (Aug 1, 2025)
f46f1cd  Update API (Aug 12, 2025)
894f62b  Add a notebook example (Aug 14, 2025)
a1afaf6  Merge branch 'main' into add_basic_deterministic_advi (Aug 14, 2025)
637fc3b  Add to API and add a docstring (Aug 14, 2025)
3e397f7  Change import in notebook (Aug 14, 2025)
d954ec7  Add jax to dependencies (Aug 14, 2025)
aad9f21  Add pytensor version (Aug 16, 2025)
ef3d86b  Fix handling of pymc model (Aug 16, 2025)
6bf92ef  Add (probably suboptimal) handling of the two backends (Aug 16, 2025)
32aff46  Add transformation (Aug 18, 2025)
138f8c2  Follow Ricardo's advice to simplify the transformation step (Aug 19, 2025)
7073a7d  Fix naming bug (Aug 19, 2025)
609aef7  Document and clean up (Aug 19, 2025)
b611d51  Merge branch 'main' into add_basic_deterministic_advi (Aug 19, 2025)
f17a090  Fix example (Aug 19, 2025)
9ab2e1e  Update pymc_extras/inference/deterministic_advi/dadvi.py (martiningram, Aug 20, 2025)
a8a53f3  Respond to comments (Aug 20, 2025)
bdee446  Fix with pre commit checks (Aug 20, 2025)
3fcafb6  Update pymc_extras/inference/deterministic_advi/dadvi.py (martiningram, Aug 28, 2025)
ad46b07  Implement suggestions (Aug 28, 2025)
6cd0184  Rename parameter because it's duplicated otherwise (Aug 28, 2025)
d648105  Rename to be consistent in use of dadvi (Aug 28, 2025)
9d18f80  Rename to `optimizer_method` and drop jac=True (Aug 28, 2025)
9f86d4f  Add jac=True back in since trust-ncg complained (Aug 28, 2025)
3b090ca  Make hessp and jac optional (Aug 28, 2025)
93cd831  Harmonize naming with existing code (Aug 28, 2025)
7b84872  Fix example (Aug 29, 2025)
7cd407e  Switch to `better_optimize` (Aug 29, 2025)
cb070aa  Replace with pt.split (Aug 29, 2025)
1,307 changes: 1,307 additions & 0 deletions notebooks/deterministic_advi_example.ipynb

Large diffs are not rendered by default.

13 changes: 12 additions & 1 deletion pymc_extras/inference/__init__.py
@@ -16,5 +16,16 @@
from pymc_extras.inference.laplace_approx.find_map import find_MAP
from pymc_extras.inference.laplace_approx.laplace import fit_laplace
from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
from pymc_extras.inference.deterministic_advi.api import (
fit_deterministic_advi as fit_deterministic_advi_jax,
)
from pymc_extras.inference.deterministic_advi.pytensor import fit_deterministic_advi

__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
__all__ = [
"fit",
"fit_pathfinder",
"fit_laplace",
"find_MAP",
"fit_deterministic_advi",

[Review comment, Member] Suggested change:
- "fit_deterministic_advi",
+ "fit_dadvi",
For brevity? Just a suggestion

"fit_deterministic_advi_jax",
]
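
For reference, a minimal sketch of how the new exports would be imported after this change (names taken from the __all__ list above; whether the shorter fit_dadvi name suggested in the review is adopted remains open):

# Illustrative only: the two new entry points exported from pymc_extras.inference
from pymc_extras.inference import fit_deterministic_advi      # PyTensor-based implementation
from pymc_extras.inference import fit_deterministic_advi_jax  # JAX-based implementation (see api.py below)
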
Empty file.
163 changes: 163 additions & 0 deletions pymc_extras/inference/deterministic_advi/api.py
@@ -0,0 +1,163 @@
from typing import Callable, Dict

import numpy as np
import pymc
import arviz as az
from jax import vmap

from pymc_extras.inference.deterministic_advi.jax import build_dadvi_funs
from pymc_extras.inference.deterministic_advi.pymc_to_jax import (
get_jax_functions_from_pymc,
transform_dadvi_draws,
)
from pymc_extras.inference.deterministic_advi.core import (
find_dadvi_optimum,
get_dadvi_draws,
DADVIFuns,
)
from pymc_extras.inference.deterministic_advi.utils import opt_callback_fun


class DADVIResult:
def __init__(
self,
fixed_draws: np.ndarray,
var_params: np.ndarray,
unflattening_fun: Callable[[np.ndarray], Dict[str, np.ndarray]],
dadvi_funs: DADVIFuns,
pymc_model: pymc.Model, # TODO Check the type here
):

self.fixed_draws = fixed_draws
self.var_params = var_params
self.unflattening_fun = unflattening_fun
self.dadvi_funs = dadvi_funs
self.n_params = self.fixed_draws.shape[1]
self.pymc_model = pymc_model

def get_posterior_means(self) -> Dict[str, np.ndarray]:
"""
Returns a dictionary with posterior means for all parameters.
"""

means = np.split(self.var_params, 2)[0]
return self.unflattening_fun(means)

def get_posterior_standard_deviations_mean_field(self) -> Dict[str, np.ndarray]:
"""
Returns a dictionary with posterior standard deviations (not LRVB-corrected, but mean field).
"""

log_sds = np.split(self.var_params, 2)[1]
sds = np.exp(log_sds)
return self.unflattening_fun(sds)

def get_posterior_draws_mean_field(
self,
n_draws: int = 1000,
seed: int = 2,
transform_draws: bool = True,
) -> Dict[str, np.ndarray]:
"""
Returns a dictionary with draws from the posterior.
"""

np.random.seed(seed)
z = np.random.randn(n_draws, self.n_params)
dadvi_draws_flat = get_dadvi_draws(self.var_params, z)

if transform_draws:

dadvi_draws = transform_dadvi_draws(
self.pymc_model,
dadvi_draws_flat,
self.unflattening_fun,
add_chain_dim=True,
)

else:

dadvi_draws = vmap(self.unflattening_fun)(dadvi_draws_flat)

return dadvi_draws

def compute_function_on_mean_field_draws(
self,
function_to_run: Callable[[Dict], np.ndarray],
n_draws: int = 1000,
seed: int = 2,
):
dadvi_dict = self.get_posterior_draws_mean_field(n_draws, seed)

return vmap(function_to_run)(dadvi_dict)


def fit_deterministic_advi(model=None, num_fixed_draws=30, seed=2):
"""
Performs inference using deterministic ADVI (automatic differentiation
variational inference with a fixed set of draws).

For full details see the paper cited in the references:
https://www.jmlr.org/papers/v25/23-1015.html

Parameters
----------
model : pm.Model
The PyMC model to be fit. If None, the current model context is used.

num_fixed_draws : int
The number of fixed draws to use for the optimisation. More
draws will result in more accurate estimates, but also
increase inference time. Usually, the default of 30 is a good
tradeoff between speed and accuracy.

seed : int
The random seed to use for the fixed draws. Running the optimisation
twice with the same seed should arrive at the same result.

Returns
-------
:class:`~arviz.InferenceData`
The inference data containing the results of the DADVI algorithm.

References
----------
Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective: Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.


"""

model = pymc.modelcontext(model) if model is None else model

np.random.seed(seed)

jax_funs = get_jax_functions_from_pymc(model)
dadvi_funs = build_dadvi_funs(jax_funs["log_posterior_fun"])

opt_callback_fun.opt_sequence = []

init_means = np.zeros(jax_funs["n_params"])
init_log_sds = np.zeros(jax_funs["n_params"]) - 3
init_var_params = np.concatenate([init_means, init_log_sds])
zs = np.random.randn(num_fixed_draws, jax_funs["n_params"])
opt = find_dadvi_optimum(
init_params=init_var_params,
zs=zs,
dadvi_funs=dadvi_funs,
verbose=True,
callback_fun=opt_callback_fun,
)

dadvi_result = DADVIResult(
fixed_draws=zs,
var_params=opt["opt_result"].x,
unflattening_fun=jax_funs["unflatten_fun"],
dadvi_funs=dadvi_funs,
pymc_model=model,
)

# Get draws and convert them into the expected ArviZ format
draws = dadvi_result.get_posterior_draws_mean_field(transform_draws=True)
az_draws = az.convert_to_inference_data(draws)

return az_draws
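
For context, a minimal usage sketch of fit_deterministic_advi as defined above (exported as fit_deterministic_advi_jax in __init__.py). The toy model and data are illustrative, and the sketch assumes the transformed draws use the model's variable names:

import numpy as np
import pymc as pm

from pymc_extras.inference import fit_deterministic_advi_jax

# Toy data: noisy observations with unknown mean and scale
rng = np.random.default_rng(0)
y = rng.normal(loc=1.5, scale=1.0, size=200)

with pm.Model() as toy_model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=y)

# The default of 30 fixed draws is usually a good speed/accuracy tradeoff
idata = fit_deterministic_advi_jax(toy_model, num_fixed_draws=30, seed=2)

print(idata.posterior["mu"].mean(), idata.posterior["sigma"].mean())
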
123 changes: 123 additions & 0 deletions pymc_extras/inference/deterministic_advi/core.py
@@ -0,0 +1,123 @@
"""
Core computations for DADVI.
"""

from typing import NamedTuple, Callable, Optional, Dict

from scipy.sparse.linalg import LinearOperator

import numpy as np
from pymc_extras.inference.deterministic_advi.optimization import optimize_with_hvp


class DADVIFuns(NamedTuple):
"""
This NamedTuple holds the functions required to run DADVI.

Args:
kl_est_and_grad_fun: Function of eta [variational parameters] and zs [draws].
zs should have shape [M, D], where M is number of fixed draws and D is
problem dimension. Returns a tuple whose first argument is the estimate
of the KL divergence, and the second is its gradient w.r.t. eta.
kl_est_hvp_fun: Function of eta, zs, and b, a vector to compute the hvp
with. This should return a vector -- the result of the hvp with b.
"""

kl_est_and_grad_fun: Callable[[np.ndarray, np.ndarray], np.ndarray]
kl_est_hvp_fun: Optional[Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray]]


def find_dadvi_optimum(
init_params: np.ndarray,
zs: np.ndarray,
dadvi_funs: DADVIFuns,
opt_method: str = "trust-ncg",
callback_fun: Optional[Callable] = None,
verbose: bool = False,
) -> Dict:
"""
Optimises the DADVI objective.

Args:
init_params: The initial variational parameters to use. This should be a
vector of length 2D, where D is the problem dimension. The first D
entries specify the variational means, while the last D specify the log
standard deviations.
zs: The fixed draws to use in the optimisation. They must be of shape
[M, D], where D is the problem dimension and M is the number of fixed
draws.
dadvi_funs: The objective to optimise. See the definition of DADVIFuns for
more information. The kl_est_and_grad_fun is required for optimisation;
the kl_est_hvp_fun is needed only for some optimisers.
opt_method: The optimisation method to use. This must be one of the methods
listed for scipy.optimize.minimize
[https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html].
Defaults to trust-ncg, which requires the hvp to be available. For
gradient-only optimisation, L-BFGS-B generally works well.
callback_fun: If provided, this callback function is passed to
scipy.optimize.minimize. See that function's documentation for more.
verbose: If True, prints the progress of the optimisation by showing the
value and gradient norm at each iteration of the optimizer.

Returns:
A dictionary with entries "opt_result", containing the results of running
scipy.optimize.minimize, and "evaluation_count", containing the number of
times the hvp and gradient functions were called.
"""

val_and_grad_fun = lambda var_params: dadvi_funs.kl_est_and_grad_fun(var_params, zs)
hvp_fun = (
None
if dadvi_funs.kl_est_hvp_fun is None
else lambda var_params, b: dadvi_funs.kl_est_hvp_fun(var_params, zs, b)
)

opt_result, eval_count = optimize_with_hvp(
val_and_grad_fun,
hvp_fun,
init_params,
opt_method=opt_method,
callback_fun=callback_fun,
verbose=verbose,
)

to_return = {
"opt_result": opt_result,
"evaluation_count": eval_count,
}

# TODO: Here I originally had a Newton step check to assess
# convergence. Could add this back in.

return to_return


def get_dadvi_draws(var_params: np.ndarray, zs: np.ndarray) -> np.ndarray:
"""
Computes draws from the mean-field variational approximation given
variational parameters and a matrix of fixed draws.

Args:
var_params: A vector of length 2D, the first D entries specifying the
means for the D model parameters, and the last D the log standard
deviations.
zs: A matrix of shape [N, D], containing the draws to use to sample the
variational approximation.

Returns:
A matrix of shape [N, D] containing N draws from the variational
approximation.
"""

# TODO: Could use JAX here
means, log_sds = np.split(var_params, 2)
sds = np.exp(log_sds)

draws = means.reshape(1, -1) + zs * sds.reshape(1, -1)

return draws


# TODO -- I think the functions above cover the basic functionality of
# fixed-draw ADVI. But I have not yet included the LRVB portion of the
# code, in the interest of keeping it simple. Can add later.
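
A small sketch of the variational-parameter layout and the reparameterised draws produced by get_dadvi_draws; the dimensions and values below are illustrative:

import numpy as np

from pymc_extras.inference.deterministic_advi.core import get_dadvi_draws

D, M = 3, 5  # problem dimension and number of draws
means = np.array([0.0, 1.0, -1.0])
log_sds = np.log([1.0, 0.5, 2.0])
var_params = np.concatenate([means, log_sds])  # length 2D: means first, then log standard deviations

zs = np.random.default_rng(0).standard_normal((M, D))
draws = get_dadvi_draws(var_params, zs)  # shape [M, D]

# The draws are the usual location-scale reparameterisation, elementwise:
assert np.allclose(draws, means + np.exp(log_sds) * zs)
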
63 changes: 63 additions & 0 deletions pymc_extras/inference/deterministic_advi/jax.py
@@ -0,0 +1,63 @@
from typing import Callable, Dict, Tuple
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap, value_and_grad, jvp, grad
from functools import partial

from pymc_extras.inference.deterministic_advi.core import DADVIFuns
from pymc_extras.inference.deterministic_advi.optimization import count_decorator


@partial(jit, static_argnums=0)
def hvp(f, primals, tangents):
# Taken (and slightly modified) from:
# https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
return jvp(grad(f), (primals,), (tangents,))[1]


@jit
def _make_draws(z, mean, log_sd):

draw = z * jnp.exp(log_sd) + mean

return draw


@jit
def _calculate_entropy(log_sds):

return jnp.sum(log_sds)


def build_dadvi_funs(log_posterior_fn: Callable[[jnp.ndarray], float]) -> DADVIFuns:
"""
Builds the DADVIFuns from a log posterior density function written in JAX.
"""

def single_log_posterior_fun(cur_z, var_params):
means, log_sds = jnp.split(var_params, 2)
cur_theta = _make_draws(cur_z, means, log_sds)
return log_posterior_fn(cur_theta)

def log_posterior_expectation(zs, var_params):
single_curried = partial(single_log_posterior_fun, var_params=var_params)
log_posts = vmap(single_curried)(zs)
return jnp.mean(log_posts)

def full_kl_est(var_params, zs):
_, log_sds = jnp.split(var_params, 2)
log_posterior = log_posterior_expectation(zs, var_params)
entropy = _calculate_entropy(log_sds)
return -log_posterior - entropy

@jit
def kl_est_hvp_fun(var_params, zs, b):
rel_kl_est = partial(full_kl_est, zs=zs)
rel_hvp = lambda x, y: hvp(rel_kl_est, x, y)
return rel_hvp(var_params, b)

kl_est_and_grad_fun = jit(value_and_grad(full_kl_est))

return DADVIFuns(
kl_est_and_grad_fun=kl_est_and_grad_fun, kl_est_hvp_fun=kl_est_hvp_fun
)
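
A minimal end-to-end sketch combining build_dadvi_funs with the core routines above, using a standard-normal log posterior written in JAX; the target, dimension, and initialisation are illustrative assumptions, not part of this PR:

import numpy as np
import jax.numpy as jnp

from pymc_extras.inference.deterministic_advi.core import find_dadvi_optimum, get_dadvi_draws
from pymc_extras.inference.deterministic_advi.jax import build_dadvi_funs

D = 2

def log_posterior(theta):
    # Unnormalised standard-normal log density over D parameters
    return -0.5 * jnp.sum(theta**2)

dadvi_funs = build_dadvi_funs(log_posterior)

zs = np.random.default_rng(0).standard_normal((30, D))  # M = 30 fixed draws
init_var_params = np.concatenate([np.zeros(D), np.zeros(D) - 3.0])  # means, then log standard deviations

# Uses the default trust-ncg optimiser, which relies on the hvp built above
opt = find_dadvi_optimum(init_params=init_var_params, zs=zs, dadvi_funs=dadvi_funs)

# Draws from the fitted mean-field approximation; for this target the optimal
# means and log standard deviations should both come out close to zero.
posterior_draws = get_dadvi_draws(opt["opt_result"].x, zs)
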