Commit 488bd9c

Author: Martin Ingram (committed)
Add first version of deterministic ADVI
1 parent 24930b5 commit 488bd9c
7 files changed: +500 −0 lines changed

pymc_extras/inference/deterministic_advi/__init__.py

Whitespace-only changes.
Lines changed: 125 additions & 0 deletions
from typing import Callable, Dict

import numpy as np
import pymc
import arviz as az
from jax import vmap

from pymc_extras.inference.deterministic_advi.jax import build_dadvi_funs
from pymc_extras.inference.deterministic_advi.pymc_to_jax import (
    get_jax_functions_from_pymc,
    transform_dadvi_draws,
)
from pymc_extras.inference.deterministic_advi.core import (
    find_dadvi_optimum,
    get_dadvi_draws,
    DADVIFuns,
)
from pymc_extras.inference.deterministic_advi.utils import opt_callback_fun


class DADVIResult:
    def __init__(
        self,
        fixed_draws: np.ndarray,
        var_params: np.ndarray,
        unflattening_fun: Callable[[np.ndarray], Dict[str, np.ndarray]],
        dadvi_funs: DADVIFuns,
        pymc_model: pymc.Model,  # TODO Check the type here
    ):

        self.fixed_draws = fixed_draws
        self.var_params = var_params
        self.unflattening_fun = unflattening_fun
        self.dadvi_funs = dadvi_funs
        self.n_params = self.fixed_draws.shape[1]
        self.pymc_model = pymc_model

    def get_posterior_means(self) -> Dict[str, np.ndarray]:
        """
        Returns a dictionary with posterior means for all parameters.
        """

        means = np.split(self.var_params, 2)[0]
        return self.unflattening_fun(means)

    def get_posterior_standard_deviations_mean_field(self) -> Dict[str, np.ndarray]:
        """
        Returns a dictionary with posterior standard deviations (mean-field, not LRVB-corrected).
        """

        log_sds = np.split(self.var_params, 2)[1]
        sds = np.exp(log_sds)
        return self.unflattening_fun(sds)

    def get_posterior_draws_mean_field(
        self,
        n_draws: int = 1000,
        seed: int = 2,
        transform_draws: bool = True,
    ) -> Dict[str, np.ndarray]:
        """
        Returns a dictionary with draws from the mean-field variational posterior.
        """

        np.random.seed(seed)
        z = np.random.randn(n_draws, self.n_params)
        dadvi_draws_flat = get_dadvi_draws(self.var_params, z)

        if transform_draws:

            dadvi_draws = transform_dadvi_draws(
                self.pymc_model, dadvi_draws_flat, self.unflattening_fun
            )

        else:

            dadvi_draws = vmap(self.unflattening_fun)(dadvi_draws_flat)

        return dadvi_draws

    def compute_function_on_mean_field_draws(
        self,
        function_to_run: Callable[[Dict], np.ndarray],
        n_draws: int = 1000,
        seed: int = 2,
    ):
        dadvi_dict = self.get_posterior_draws_mean_field(n_draws, seed)

        return vmap(function_to_run)(dadvi_dict)


def fit_pymc_dadvi_with_jax(pymc_model, num_fixed_draws=30, seed=2):
    np.random.seed(seed)

    jax_funs = get_jax_functions_from_pymc(pymc_model)
    dadvi_funs = build_dadvi_funs(jax_funs["log_posterior_fun"])

    opt_callback_fun.opt_sequence = []

    init_means = np.zeros(jax_funs["n_params"])
    init_log_vars = np.zeros(jax_funs["n_params"]) - 3
    init_var_params = np.concatenate([init_means, init_log_vars])
    zs = np.random.randn(num_fixed_draws, jax_funs["n_params"])
    opt = find_dadvi_optimum(
        init_params=init_var_params,
        zs=zs,
        dadvi_funs=dadvi_funs,
        verbose=True,
        callback_fun=opt_callback_fun,
    )

    dadvi_result = DADVIResult(
        fixed_draws=zs,
        var_params=opt["opt_result"].x,
        unflattening_fun=jax_funs["unflatten_fun"],
        dadvi_funs=dadvi_funs,
        pymc_model=pymc_model,
    )

    # Get draws and convert them into the expected arviz format
    draws = dadvi_result.get_posterior_draws_mean_field(transform_draws=True)

    az_draws = az.convert_to_inference_data(draws)

    return az_draws
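
As a quick orientation (not part of the commit), here is a minimal usage sketch of fit_pymc_dadvi_with_jax on a toy PyMC model. The model and variable names are illustrative, and the function is assumed to be in scope from the file above, since the commit does not show this file's module path:

import numpy as np
import pymc as pm

# Toy model: infer the mean of normally distributed observations.
rng = np.random.default_rng(0)
observed = rng.normal(loc=1.0, scale=2.0, size=100)

with pm.Model() as toy_model:
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    pm.Normal("obs", mu=mu, sigma=2.0, observed=observed)

# fit_pymc_dadvi_with_jax is defined above; it runs the deterministic ADVI
# optimisation and returns an arviz InferenceData built from mean-field draws.
idata = fit_pymc_dadvi_with_jax(toy_model, num_fixed_draws=30, seed=2)
print(idata)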
Lines changed: 123 additions & 0 deletions
"""
Core computations for DADVI.
"""

from typing import NamedTuple, Callable, Optional, Dict

from scipy.sparse.linalg import LinearOperator

import numpy as np
from pymc_extras.inference.deterministic_advi.optimization import optimize_with_hvp


class DADVIFuns(NamedTuple):
    """
    This NamedTuple holds the functions required to run DADVI.

    Args:
        kl_est_and_grad_fun: Function of eta [variational parameters] and zs [draws].
            zs should have shape [M, D], where M is the number of fixed draws and D is
            the problem dimension. Returns a tuple whose first element is the estimate
            of the KL divergence, and whose second is its gradient w.r.t. eta.
        kl_est_hvp_fun: Function of eta, zs, and b, a vector to compute the HVP
            with. This should return a vector -- the result of the HVP with b.
    """

    kl_est_and_grad_fun: Callable[[np.ndarray, np.ndarray], np.ndarray]
    kl_est_hvp_fun: Optional[Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray]]


def find_dadvi_optimum(
    init_params: np.ndarray,
    zs: np.ndarray,
    dadvi_funs: DADVIFuns,
    opt_method: str = "trust-ncg",
    callback_fun: Optional[Callable] = None,
    verbose: bool = False,
) -> Dict:
    """
    Optimises the DADVI objective.

    Args:
        init_params: The initial variational parameters to use. This should be a
            vector of length 2D, where D is the problem dimension. The first D
            entries specify the variational means, while the last D specify the log
            standard deviations.
        zs: The fixed draws to use in the optimisation. They must be of shape
            [M, D], where D is the problem dimension and M is the number of fixed
            draws.
        dadvi_funs: The objective to optimise. See the definition of DADVIFuns for
            more information. The kl_est_and_grad_fun is required for optimisation;
            the kl_est_hvp_fun is needed only for some optimisers.
        opt_method: The optimisation method to use. This must be one of the methods
            listed for scipy.optimize.minimize
            [https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html].
            Defaults to trust-ncg, which requires the HVP to be available. For
            gradient-only optimisation, L-BFGS-B generally works well.
        callback_fun: If provided, this callback function is passed to
            scipy.optimize.minimize. See that function's documentation for more.
        verbose: If True, prints the progress of the optimisation by showing the
            value and gradient norm at each iteration of the optimizer.

    Returns:
        A dictionary with entries "opt_result", containing the results of running
        scipy.optimize.minimize, and "evaluation_count", containing the number of
        times the HVP and gradient functions were called.
    """

    val_and_grad_fun = lambda var_params: dadvi_funs.kl_est_and_grad_fun(var_params, zs)
    hvp_fun = (
        None
        if dadvi_funs.kl_est_hvp_fun is None
        else lambda var_params, b: dadvi_funs.kl_est_hvp_fun(var_params, zs, b)
    )

    opt_result, eval_count = optimize_with_hvp(
        val_and_grad_fun,
        hvp_fun,
        init_params,
        opt_method=opt_method,
        callback_fun=callback_fun,
        verbose=verbose,
    )

    to_return = {
        "opt_result": opt_result,
        "evaluation_count": eval_count,
    }

    # TODO: Here I originally had a Newton step check to assess
    # convergence. Could add this back in.

    return to_return


def get_dadvi_draws(var_params: np.ndarray, zs: np.ndarray) -> np.ndarray:
    """
    Computes draws from the mean-field variational approximation given
    variational parameters and a matrix of fixed draws.

    Args:
        var_params: A vector of shape 2D, the first D entries specifying the
            means for the D model parameters, and the last D the log standard
            deviations.
        zs: A matrix of shape [N, D], containing the draws to use to sample the
            variational approximation.

    Returns:
        A matrix of shape [N, D] containing N draws from the variational
        approximation.
    """

    # TODO: Could use JAX here
    means, log_sds = np.split(var_params, 2)
    sds = np.exp(log_sds)

    draws = means.reshape(1, -1) + zs * sds.reshape(1, -1)

    return draws


# TODO -- I think the functions above cover the basic functionality of
# fixed-draw ADVI. But I have not yet included the LRVB portion of the
# code, in the interest of keeping it simple. Can add later.
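
To make the shape conventions documented above concrete, here is a small sketch (not part of the commit) exercising get_dadvi_draws with hand-picked variational parameters:

import numpy as np

D, N = 3, 5
means = np.array([0.0, 1.0, -1.0])
log_sds = np.log(np.array([1.0, 0.5, 2.0]))
var_params = np.concatenate([means, log_sds])  # length 2D: means, then log sds

zs = np.random.randn(N, D)                     # N fixed draws of dimension D
draws = get_dadvi_draws(var_params, zs)        # reparameterised draws, shape [N, D]
assert draws.shape == (N, D)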
Lines changed: 63 additions & 0 deletions
from typing import Callable, Dict, Tuple
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap, value_and_grad, jvp, grad
from functools import partial

from pymc_extras.inference.deterministic_advi.core import DADVIFuns
from pymc_extras.inference.deterministic_advi.optimization import count_decorator


@partial(jit, static_argnums=0)
def hvp(f, primals, tangents):
    # Taken (and slightly modified) from:
    # https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
    return jvp(grad(f), (primals,), (tangents,))[1]


@jit
def _make_draws(z, mean, log_sd):

    # Reparameterised draw: theta = mean + sd * z.
    draw = z * jnp.exp(log_sd) + mean

    return draw


@jit
def _calculate_entropy(log_sds):

    # Entropy of the mean-field Gaussian, up to an additive constant.
    return jnp.sum(log_sds)


def build_dadvi_funs(log_posterior_fn: Callable[[jnp.ndarray], float]) -> DADVIFuns:
    """
    Builds the DADVIFuns from a log posterior density function written in JAX.
    """

    def single_log_posterior_fun(cur_z, var_params):
        means, log_sds = jnp.split(var_params, 2)
        cur_theta = _make_draws(cur_z, means, log_sds)
        return log_posterior_fn(cur_theta)

    def log_posterior_expectation(zs, var_params):
        single_curried = partial(single_log_posterior_fun, var_params=var_params)
        log_posts = vmap(single_curried)(zs)
        return jnp.mean(log_posts)

    def full_kl_est(var_params, zs):
        _, log_sds = jnp.split(var_params, 2)
        log_posterior = log_posterior_expectation(zs, var_params)
        entropy = _calculate_entropy(log_sds)
        return -log_posterior - entropy

    @jit
    def kl_est_hvp_fun(var_params, zs, b):
        rel_kl_est = partial(full_kl_est, zs=zs)
        rel_hvp = lambda x, y: hvp(rel_kl_est, x, y)
        return rel_hvp(var_params, b)

    kl_est_and_grad_fun = jit(value_and_grad(full_kl_est))

    return DADVIFuns(
        kl_est_and_grad_fun=kl_est_and_grad_fun, kl_est_hvp_fun=kl_est_hvp_fun
    )
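
For completeness, a small sketch (not part of the commit) of how build_dadvi_funs could be exercised directly, using an unnormalised standard-normal log density as the JAX log posterior. The returned functions follow the DADVIFuns signatures documented in the core module, and could then be passed to find_dadvi_optimum:

import numpy as np
import jax.numpy as jnp

def toy_log_posterior(theta):
    # Unnormalised standard-normal log density.
    return -0.5 * jnp.sum(theta ** 2)

dadvi_funs = build_dadvi_funs(toy_log_posterior)

D, M = 2, 30
var_params = np.zeros(2 * D)      # means = 0, log standard deviations = 0
zs = np.random.randn(M, D)        # M fixed draws of dimension D

kl_value, kl_grad = dadvi_funs.kl_est_and_grad_fun(var_params, zs)
hvp_result = dadvi_funs.kl_est_hvp_fun(var_params, zs, np.ones(2 * D))
print(kl_value, kl_grad.shape, hvp_result.shape)  # scalar, (2D,), (2D,)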
