Skip to content

Commit 62377ce

Browse files
Move VI autoguide work from pymc repo (#636)
* Move stuff over from pymc * Assorted fixes * remove licensing header * Fix STL estimator * Add guide model docstring --------- Co-authored-by: jessegrabowski <jessegrabowski@gmail.com>
1 parent 12a18c2 commit 62377ce

File tree

8 files changed

+1095
-0
lines changed

8 files changed

+1095
-0
lines changed

notebooks/ADVI Guide API.ipynb

Lines changed: 681 additions & 0 deletions
Large diffs are not rendered by default.

pymc_extras/inference/advi/__init__.py

Whitespace-only changes.
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from dataclasses import dataclass, field
2+
3+
import numpy as np
4+
import pytensor.tensor as pt
5+
6+
from pymc.distributions import Normal
7+
from pymc.logprob.basic import conditional_logp
8+
from pymc.model.core import Deterministic, Model
9+
from pytensor import graph_replace
10+
from pytensor.gradient import disconnected_grad
11+
from pytensor.graph.basic import Variable
12+
13+
from pymc_extras.inference.advi.pytensorf import get_symbolic_rv_shapes
14+
15+
16+
@dataclass(frozen=True)
class AutoGuideModel:
    """Pair a variational guide ``Model`` with the initial values of its trainable parameters.

    Attributes
    ----------
    model : Model
        The guide (variational) model.
    params_init_values : dict[Variable, np.ndarray]
        Mapping from each variational parameter variable to its initial numeric value.
    name_to_param : dict[str, Variable]
        Derived lookup table from parameter name to parameter variable (filled in
        ``__post_init__``; not an ``__init__`` argument).
    """

    model: Model
    params_init_values: dict[Variable, np.ndarray]
    name_to_param: dict[str, Variable] = field(init=False)

    def __post_init__(self):
        # The dataclass is frozen, so the derived lookup table has to be set
        # through object.__setattr__ to bypass the blocked __setattr__.
        lookup = {param.name: param for param in self.params_init_values}
        object.__setattr__(self, "name_to_param", lookup)

    @property
    def params(self) -> tuple[Variable, ...]:
        """All variational parameter variables, in insertion order."""
        return tuple(self.params_init_values)

    def __getitem__(self, name: str) -> Variable:
        """Look up a variational parameter variable by name."""
        return self.name_to_param[name]

    def stochastic_logq(self, stick_the_landing: bool = True) -> pt.TensorVariable:
        """Returns a graph representing the logp of the guide model, evaluated under draws from its random variables."""
        # Value variables are the deterministics themselves: the logp is evaluated
        # at fresh draws from the guide's random variables.
        value_map = {det: det for det in self.model.deterministics}
        term_dict = conditional_logp(value_map, warn_rvs=False)
        logq = pt.sum([term.sum() for term in term_dict.values()])

        if stick_the_landing:
            # Detach variational parameters from the gradient computation of logq
            logq = graph_replace(
                logq, {param: disconnected_grad(param) for param in self.params}
            )

        return logq
50+
51+
52+
def AutoDiagonalNormal(model: Model) -> AutoGuideModel:
    """
    Create a guide model for ADVI with a mean-field normal distribution.

    A guide model is a variational distribution that approximates the posterior distribution of the model's free
    random variables. In this case, we use a mean-field normal distribution, which assumes that the free random
    variables are independent and normally distributed. For details, see [1]_.

    For each free random variable in the model, we create a corresponding random variable in the guide model with a
    normal distribution. The mean and standard deviation of each normal distribution are parameterized by learnable
    parameters (loc and scale). ``loc`` is initialized to small uniform random values in (-1, 1), while ``scale`` is
    initialized to a constant 0.1 (pre-softplus; the effective standard deviation is ``softplus(scale)``).

    Parameters
    ----------
    model : Model
        The probabilistic model for which to create the guide.

    Returns
    -------
    guide_model : AutoGuideModel
        An AutoGuideModel containing the guide model and the initial values for its parameters.

    References
    ----------
    .. [1] Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, and David M. Blei. Automatic Differentiation
        Variational Inference. Journal of Machine Learning Research, 18(14):1–45, 2017.
    """
    coords = model.coords
    free_rvs = model.free_RVs

    # Shapes are resolved symbolically so that the init values can be evaluated
    # even when the static type shape of an RV is unknown.
    free_rv_shapes = dict(zip(free_rvs, get_symbolic_rv_shapes(free_rvs)))
    params_init_values = {}

    with Model(coords=coords) as guide_model:
        for rv in free_rvs:
            # Free (non-random) parameter tensors for the variational mean and scale.
            loc = pt.tensor(f"{rv.name}_loc", shape=rv.type.shape)
            scale = pt.tensor(f"{rv.name}_scale", shape=rv.type.shape)
            # TODO: Make these customizable
            params_init_values[loc] = pt.random.uniform(-1, 1, size=free_rv_shapes[rv]).eval()
            params_init_values[scale] = pt.full(free_rv_shapes[rv], 0.1).eval()

            # Non-centered parameterization: draw a standard normal and shift/scale it.
            z = Normal(
                f"{rv.name}_z",
                mu=0,
                sigma=1,
                shape=free_rv_shapes[rv],
            )
            # softplus keeps the effective standard deviation positive; the
            # Deterministic reuses the original RV's name (and dims, if any) so the
            # guide variable can be matched back to the model variable by name.
            Deterministic(
                rv.name,
                loc + pt.softplus(scale) * z,
                dims=model.named_vars_to_dims.get(rv.name, None),
            )

    return AutoGuideModel(guide_model, params_init_values)
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from pymc import Model
2+
from pytensor import graph_replace
3+
from pytensor.tensor import TensorVariable
4+
5+
from pymc_extras.inference.advi.autoguide import AutoGuideModel
6+
7+
8+
def get_logp_logq(model: Model, guide: AutoGuideModel, stick_the_landing: bool = True):
    """
    Compute the log probability of the model and the guide.

    Parameters
    ----------
    model : Model
        The probabilistic model.
    guide : AutoGuideModel
        The variational guide.
    stick_the_landing : bool, optional
        Whether to use the stick-the-landing (STL) gradient estimator, by default True.
        The STL estimator has lower gradient variance by removing the score function term
        from the gradient. When True, gradients are stopped from flowing through logq.

    Returns
    -------
    logp : TensorVariable
        Log probability of the model.
    logq : TensorVariable
        Log probability of the guide.
    """
    # Substitute each free RV's value variable with the guide's matching
    # deterministic (looked up by name); observed RVs keep their data.
    replacements = {}
    for rv, value_var in model.rvs_to_values.items():
        if rv in model.observed_RVs:
            continue
        replacements[value_var] = guide.model[rv.name]

    logq = guide.stochastic_logq(stick_the_landing=stick_the_landing)
    logp = graph_replace(model.logp(), replacements)

    return logp, logq
41+
42+
43+
def advi_objective(logp: TensorVariable, logq: TensorVariable):
    """Compute the negative ELBO objective for ADVI.

    Parameters
    ----------
    logp : TensorVariable
        Log probability of the model.
    logq : TensorVariable
        Log probability of the guide.

    Returns
    -------
    TensorVariable
        The negative ELBO.
    """
    # ELBO = E[logp - logq]; minimizing its negation maximizes the ELBO.
    return logq - logp
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Sequence
4+
from typing import cast
5+
6+
from pymc import SymbolicRandomVariable
7+
from pymc.distributions.shape_utils import change_dist_size
8+
from pytensor import config
9+
from pytensor import tensor as pt
10+
from pytensor.graph import FunctionGraph, ancestors, vectorize_graph
11+
from pytensor.tensor import TensorLike, TensorVariable
12+
from pytensor.tensor.basic import infer_shape_db
13+
from pytensor.tensor.random.op import RandomVariable
14+
from pytensor.tensor.rewriting.shape import ShapeFeature
15+
16+
17+
def vectorize_random_graph(
    graph: Sequence[TensorVariable], batch_draws: TensorLike
) -> list[TensorVariable]:
    """Vectorize ``graph`` by expanding every root random node with a leading
    batch dimension of size ``batch_draws``."""
    # Collect every random node among the ancestors of the outputs.
    random_ops = (RandomVariable, SymbolicRandomVariable)
    all_rvs = tuple(
        node
        for node in ancestors(graph)
        if node.owner is not None and isinstance(node.owner.op, random_ops)
    )

    # A root RV is one with no other RV among its direct inputs.
    rv_pool = set(all_rvs)
    root_rvs = tuple(rv for rv in all_rvs if rv_pool.isdisjoint(rv.owner.inputs))

    # Vectorize the graph by resizing the roots; downstream nodes are
    # re-derived by vectorize_graph.
    size = pt.as_tensor(batch_draws, dtype=int)
    replacements = {
        rv: change_dist_size(rv, new_size=size, expand=True) for rv in root_rvs
    }
    return cast(list[TensorVariable], vectorize_graph(graph, replace=replacements))
39+
40+
41+
def get_symbolic_rv_shapes(
    rvs: Sequence[TensorVariable], raise_if_rvs_in_graph: bool = True
) -> tuple[TensorVariable, ...]:
    """Return symbolic shape graphs for ``rvs`` with the random draws rewritten out.

    Parameters
    ----------
    rvs : Sequence[TensorVariable]
        Random variables whose shapes are needed.
    raise_if_rvs_in_graph : bool, optional
        If True (default), raise if any of the requested ``rvs`` still appear in
        the rewritten shape graphs (i.e. the shapes could not be fully decoupled
        from the draws themselves).

    Returns
    -------
    tuple[TensorVariable, ...]
        One symbolic shape per input RV.

    Raises
    ------
    ValueError
        If ``raise_if_rvs_in_graph`` is True and some shape still depends on an RV.
    """
    # TODO: Move me to pymc.pytensorf, this is needed often

    rv_shapes = [rv.shape for rv in rvs]
    shape_fg = FunctionGraph(outputs=rv_shapes, features=[ShapeFeature()], clone=True)
    # Apply only the shape-inference rewrite database; cxx="" avoids any C
    # compilation during the rewrite pass.
    with config.change_flags(optdb__max_use_ratio=10, cxx=""):
        infer_shape_db.default_query.rewrite(shape_fg)
    rv_shapes = shape_fg.outputs

    if raise_if_rvs_in_graph and (overlap := (set(rvs) & set(ancestors(rv_shapes)))):
        # Fixed grammar: "depend the following" -> "depend on the following".
        raise ValueError(f"rv_shapes still depend on the following rvs {overlap}")

    return cast(tuple[TensorVariable, ...], tuple(rv_shapes))
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from typing import Protocol
2+
3+
import numpy as np
4+
5+
from pymc import Model, compile
6+
from pymc.pytensorf import rewrite_pregrad
7+
from pytensor import tensor as pt
8+
9+
from pymc_extras.inference.advi.autoguide import AutoGuideModel
10+
from pymc_extras.inference.advi.objective import advi_objective, get_logp_logq
11+
from pymc_extras.inference.advi.pytensorf import vectorize_random_graph
12+
13+
14+
class TrainingFn(Protocol):
    """Signature of the compiled SVI training function.

    Called with the number of Monte Carlo draws followed by the current values of
    the variational parameters; returns the loss followed by one gradient array
    per parameter (matching the compiled outputs ``[negative_elbo, *grads]``).
    """

    def __call__(self, draws: int, *params: np.ndarray) -> tuple[np.ndarray, ...]: ...
16+
17+
18+
def compile_svi_training_fn(
    model: Model, guide: AutoGuideModel, stick_the_landing: bool = True, **compile_kwargs
) -> TrainingFn:
    """Compile a function returning the negative ELBO and its gradients.

    Parameters
    ----------
    model : Model
        The probabilistic model.
    guide : AutoGuideModel
        The variational guide whose parameters will be optimized.
    stick_the_landing : bool, optional
        Whether to use the stick-the-landing gradient estimator, by default True.
    **compile_kwargs
        Extra keyword arguments forwarded to ``pymc.compile``. ``trust_input=True``
        is used unless the caller overrides it.

    Returns
    -------
    TrainingFn
        Compiled function ``f(draws, *params) -> (negative_elbo, *gradients)``.
    """
    draws = pt.scalar("draws", dtype=int)
    params = guide.params

    logp, logq = get_logp_logq(model, guide, stick_the_landing=stick_the_landing)

    # Monte Carlo estimate of the objective: vectorize the scalar objective over
    # `draws` independent samples and average.
    scalar_negative_elbo = advi_objective(logp, logq)
    [negative_elbo_draws] = vectorize_random_graph([scalar_negative_elbo], batch_draws=draws)
    negative_elbo = negative_elbo_draws.mean(axis=0)

    negative_elbo_grads = pt.grad(rewrite_pregrad(negative_elbo), wrt=params)

    # Skip runtime input validation in the compiled function unless the caller
    # explicitly asks otherwise.
    compile_kwargs.setdefault("trust_input", True)

    f_loss_dloss = compile(
        inputs=[draws, *params], outputs=[negative_elbo, *negative_elbo_grads], **compile_kwargs
    )

    return f_loss_dloss

tests/inference/advi/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)