Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
681 changes: 681 additions & 0 deletions notebooks/ADVI Guide API.ipynb

Large diffs are not rendered by default.

Empty file.
105 changes: 105 additions & 0 deletions pymc_extras/inference/advi/autoguide.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from dataclasses import dataclass, field

import numpy as np
import pytensor.tensor as pt

from pymc.distributions import Normal
from pymc.logprob.basic import conditional_logp
from pymc.model.core import Deterministic, Model
from pytensor import graph_replace
from pytensor.gradient import disconnected_grad
from pytensor.graph.basic import Variable

from pymc_extras.inference.advi.pytensorf import get_symbolic_rv_shapes


@dataclass(frozen=True)
class AutoGuideModel:
    """A variational guide model together with initial values for its parameters.

    Attributes
    ----------
    model : Model
        The PyMC model acting as the variational guide.
    params_init_values : dict[Variable, np.ndarray]
        Mapping from each variational parameter to its initial numeric value.
    name_to_param : dict[str, Variable]
        Derived lookup table from parameter name to parameter variable.
    """

    model: Model
    params_init_values: dict[Variable, np.ndarray]
    name_to_param: dict[str, Variable] = field(init=False)

    def __post_init__(self):
        # The dataclass is frozen, so the derived lookup table must be
        # installed via object.__setattr__.
        by_name = {param.name: param for param in self.params_init_values}
        object.__setattr__(self, "name_to_param", by_name)

    @property
    def params(self) -> tuple[Variable, ...]:
        """Variational parameters, in the insertion order of ``params_init_values``."""
        return tuple(self.params_init_values)

    def __getitem__(self, name: str) -> Variable:
        """Return the variational parameter with the given name."""
        return self.name_to_param[name]

    def stochastic_logq(self, stick_the_landing: bool = True) -> pt.TensorVariable:
        """Returns a graph representing the logp of the guide model, evaluated under draws from its random variables."""
        density_terms = conditional_logp(
            {det: det for det in self.model.deterministics},
            warn_rvs=False,
        )
        logq = pt.sum([term.sum() for term in density_terms.values()])

        if not stick_the_landing:
            return logq

        # Stick-the-landing: stop gradients from flowing through the
        # variational parameters inside logq, reducing gradient variance.
        replacements = {param: disconnected_grad(param) for param in self.params}
        return graph_replace(logq, replacements)


def AutoDiagonalNormal(model: Model) -> AutoGuideModel:
    """
    Create a guide model for ADVI with a mean-field normal distribution.

    A guide model is a variational distribution that approximates the posterior distribution of the model's free
    random variables. Here the guide is a mean-field normal: each free random variable is treated as independent
    and normally distributed. For details, see _[1].

    For every free random variable in ``model``, the guide contains a standard-normal variable ``<name>_z`` and a
    deterministic reparameterization ``loc + softplus(scale) * z``, where ``loc`` and ``scale`` are learnable
    parameters initialized to small values.

    Parameters
    ----------
    model : Model
        The probabilistic model for which to create the guide.

    Returns
    -------
    guide_model : AutoGuideModel
        An AutoGuideModel containing the guide model and the initial values for its parameters.

    References
    ----------
    .. [1] Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, and David M. Blei. Automatic Differentiation
        Variational Inference. Journal of Machine Learning Research, 18(14):1–45, 2017.
    """
    free_rvs = model.free_RVs
    rv_to_shape = dict(zip(free_rvs, get_symbolic_rv_shapes(free_rvs)))

    params_init_values: dict[Variable, np.ndarray] = {}

    with Model(coords=model.coords) as guide_model:
        for rv in free_rvs:
            sym_shape = rv_to_shape[rv]

            loc = pt.tensor(f"{rv.name}_loc", shape=rv.type.shape)
            scale = pt.tensor(f"{rv.name}_scale", shape=rv.type.shape)

            # TODO: Make these customizable
            # NOTE(review): `scale` goes through softplus below, so a raw init
            # of 0.1 gives an effective std of softplus(0.1) ~= 0.74 — confirm
            # this is the intended starting scale.
            params_init_values[loc] = pt.random.uniform(-1, 1, size=sym_shape).eval()
            params_init_values[scale] = pt.full(sym_shape, 0.1).eval()

            z = Normal(
                f"{rv.name}_z",
                mu=0,
                sigma=1,
                shape=sym_shape,
            )
            # Named after the original RV so downstream code can look up the
            # guide draw by the model variable's name.
            Deterministic(
                rv.name,
                loc + pt.softplus(scale) * z,
                dims=model.named_vars_to_dims.get(rv.name, None),
            )

    return AutoGuideModel(guide_model, params_init_values)
59 changes: 59 additions & 0 deletions pymc_extras/inference/advi/objective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from pymc import Model
from pytensor import graph_replace
from pytensor.tensor import TensorVariable

from pymc_extras.inference.advi.autoguide import AutoGuideModel


def get_logp_logq(model: Model, guide: AutoGuideModel, stick_the_landing: bool = True):
    """
    Compute the log probability of the model and the guide.

    Parameters
    ----------
    model : Model
        The probabilistic model.
    guide : AutoGuideModel
        The variational guide.
    stick_the_landing : bool, optional
        Whether to use the stick-the-landing (STL) gradient estimator, by default True.
        The STL estimator has lower gradient variance by removing the score function term
        from the gradient. When True, gradients are stopped from flowing through logq.

    Returns
    -------
    logp : TensorVariable
        Log probability of the model.
    logq : TensorVariable
        Log probability of the guide.
    """
    # Swap each free RV's value variable for the matching guide draw, so that
    # the model's logp is evaluated at samples from the guide.
    value_to_guide_draw = {}
    for rv, value_var in model.rvs_to_values.items():
        if rv in model.observed_RVs:
            continue
        value_to_guide_draw[value_var] = guide.model[rv.name]

    logp = graph_replace(model.logp(), value_to_guide_draw)
    logq = guide.stochastic_logq(stick_the_landing=stick_the_landing)

    return logp, logq


def advi_objective(logp: TensorVariable, logq: TensorVariable):
"""Compute the negative ELBO objective for ADVI.

Parameters
----------
logp : TensorVariable
Log probability of the model.
logq : TensorVariable
Log probability of the guide.

Returns
-------
TensorVariable
The negative ELBO.
"""
negative_elbo = logq - logp
return negative_elbo
55 changes: 55 additions & 0 deletions pymc_extras/inference/advi/pytensorf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import cast

from pymc import SymbolicRandomVariable
from pymc.distributions.shape_utils import change_dist_size
from pytensor import config
from pytensor import tensor as pt
from pytensor.graph import FunctionGraph, ancestors, vectorize_graph
from pytensor.tensor import TensorLike, TensorVariable
from pytensor.tensor.basic import infer_shape_db
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.rewriting.shape import ShapeFeature


def vectorize_random_graph(
    graph: Sequence[TensorVariable], batch_draws: TensorLike
) -> list[TensorVariable]:
    """Vectorize a random graph by adding a leading batch dimension of size ``batch_draws``.

    Root random variables (those not depending on any other random variable)
    are resized with an extra leading dimension, and the rest of the graph is
    vectorized to follow.
    """
    # Collect every random node reachable from the graph outputs.
    random_vars = []
    for var in ancestors(graph):
        node = var.owner
        if node is not None and isinstance(node.op, RandomVariable | SymbolicRandomVariable):
            random_vars.append(var)
    random_set = set(random_vars)

    # Roots: random nodes whose inputs contain no other random node.
    root_rvs = tuple(rv for rv in random_vars if random_set.isdisjoint(rv.owner.inputs))

    # Expanding the roots with a new leading dimension of size `batch_draws`
    # is enough; vectorize_graph propagates the batch through the rest.
    size = pt.as_tensor(batch_draws, dtype=int)
    replacements = {
        root_rv: change_dist_size(root_rv, new_size=size, expand=True)
        for root_rv in root_rvs
    }
    return cast(list[TensorVariable], vectorize_graph(graph, replace=replacements))


def get_symbolic_rv_shapes(
    rvs: Sequence[TensorVariable], raise_if_rvs_in_graph: bool = True
) -> tuple[TensorVariable, ...]:
    """Return symbolic shape graphs for a sequence of random variables.

    Shape-inference rewrites are applied so that, where possible, each returned
    shape graph no longer references the random variables themselves (e.g. the
    shape can be derived from an RV's parameters instead of the RV).

    Parameters
    ----------
    rvs : Sequence[TensorVariable]
        Random variables whose shapes are requested.
    raise_if_rvs_in_graph : bool, optional
        If True (default), raise a ValueError when some returned shape graph
        still depends on one of ``rvs`` after rewriting.

    Returns
    -------
    tuple[TensorVariable, ...]
        One symbolic shape per input RV, in the same order.

    Raises
    ------
    ValueError
        If ``raise_if_rvs_in_graph`` is True and a shape graph still depends
        on one of the input RVs.
    """
    # TODO: Move me to pymc.pytensorf, this is needed often

    rv_shapes = [rv.shape for rv in rvs]
    # ShapeFeature lets the shape-inference rewrites replace `rv.shape` with
    # graphs built from the RV's inputs wherever possible.
    shape_fg = FunctionGraph(outputs=rv_shapes, features=[ShapeFeature()], clone=True)
    # cxx="" avoids involving the C backend just for rewriting; the rewrite
    # use-ratio is set explicitly to bound the rewrite passes.
    with config.change_flags(optdb__max_use_ratio=10, cxx=""):
        infer_shape_db.default_query.rewrite(shape_fg)
    rv_shapes = shape_fg.outputs

    if raise_if_rvs_in_graph and (overlap := (set(rvs) & set(ancestors(rv_shapes)))):
        # Fixed error message wording ("depend the" -> "depend on the").
        raise ValueError(f"rv_shapes still depend on the following rvs {overlap}")

    return cast(tuple[TensorVariable, ...], tuple(rv_shapes))
39 changes: 39 additions & 0 deletions pymc_extras/inference/advi/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from typing import Protocol

import numpy as np

from pymc import Model, compile
from pymc.pytensorf import rewrite_pregrad
from pytensor import tensor as pt

from pymc_extras.inference.advi.autoguide import AutoGuideModel
from pymc_extras.inference.advi.objective import advi_objective, get_logp_logq
from pymc_extras.inference.advi.pytensorf import vectorize_random_graph


class TrainingFn(Protocol):
    """Signature of a compiled SVI training step.

    Called with the number of Monte Carlo draws followed by the current
    numeric values of the variational parameters; returns the loss followed
    by one gradient array per parameter.
    """

    def __call__(self, draws: int, *params: np.ndarray) -> tuple[np.ndarray, ...]: ...


def compile_svi_training_fn(
    model: Model, guide: AutoGuideModel, stick_the_landing: bool = True, **compile_kwargs
) -> TrainingFn:
    """Compile a function returning the mean negative ELBO and its gradients.

    Parameters
    ----------
    model : Model
        The probabilistic model.
    guide : AutoGuideModel
        The variational guide whose parameters are optimized.
    stick_the_landing : bool, optional
        Whether to use the stick-the-landing gradient estimator, by default True.
    **compile_kwargs
        Extra keyword arguments forwarded to ``pymc.compile``. ``trust_input``
        defaults to True.

    Returns
    -------
    TrainingFn
        A compiled function of ``(draws, *params)`` returning the loss
        followed by one gradient per variational parameter.
    """
    n_draws = pt.scalar("draws", dtype=int)

    logp, logq = get_logp_logq(model, guide, stick_the_landing=stick_the_landing)
    single_draw_loss = advi_objective(logp, logq)

    # Vectorize the scalar objective over `n_draws` Monte Carlo samples and
    # average to obtain the negative ELBO estimate.
    [batched_loss] = vectorize_random_graph([single_draw_loss], batch_draws=n_draws)
    loss = batched_loss.mean(axis=0)

    grads = pt.grad(rewrite_pregrad(loss), wrt=guide.params)

    compile_kwargs.setdefault("trust_input", True)

    return compile(
        inputs=[n_draws, *guide.params], outputs=[loss, *grads], **compile_kwargs
    )
Empty file.
Loading