-
Notifications
You must be signed in to change notification settings - Fork 81
Move VI autoguide work from pymc repo #636
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| from dataclasses import dataclass, field | ||
|
|
||
| import numpy as np | ||
| import pytensor.tensor as pt | ||
|
|
||
| from pymc.distributions import Normal | ||
| from pymc.logprob.basic import conditional_logp | ||
| from pymc.model.core import Deterministic, Model | ||
| from pytensor import graph_replace | ||
| from pytensor.gradient import disconnected_grad | ||
| from pytensor.graph.basic import Variable | ||
|
|
||
| from pymc_extras.inference.advi.pytensorf import get_symbolic_rv_shapes | ||
|
|
||
|
|
||
@dataclass(frozen=True)
class AutoGuideModel:
    """Pair a variational guide ``Model`` with its learnable parameters.

    Attributes
    ----------
    model : Model
        The guide (variational) model.
    params_init_values : dict[Variable, np.ndarray]
        Maps each variational parameter variable to its initial value.
    name_to_param : dict[str, Variable]
        Name-based lookup for the parameters; derived in ``__post_init__``.
    """

    model: Model
    params_init_values: dict[Variable, np.ndarray]
    name_to_param: dict[str, Variable] = field(init=False)

    def __post_init__(self):
        # The dataclass is frozen, so the derived lookup table must be
        # installed via object.__setattr__ rather than normal assignment.
        lookup = {param.name: param for param in self.params_init_values}
        object.__setattr__(self, "name_to_param", lookup)

    @property
    def params(self) -> tuple[Variable, ...]:
        """All variational parameter variables, in insertion order."""
        return tuple(self.params_init_values)

    def __getitem__(self, name: str) -> Variable:
        """Look up a variational parameter variable by its name."""
        return self.name_to_param[name]

    def stochastic_logq(self, stick_the_landing: bool = True) -> pt.TensorVariable:
        """Returns a graph representing the logp of the guide model, evaluated under draws from its random variables."""
        # Each guide density is evaluated at its own (random) draw; summing all
        # factors gives a single stochastic estimate of log q.
        logq_factors = conditional_logp(
            {det: det for det in self.model.deterministics},
            warn_rvs=False,
        )
        logq = pt.sum([factor.sum() for factor in logq_factors.values()])

        if stick_the_landing:
            # Detach variational parameters from the gradient computation of logq
            detached = {param: disconnected_grad(param) for param in self.params}
            logq = graph_replace(logq, detached)

        return logq
|
|
||
|
|
||
def AutoDiagonalNormal(model: Model) -> AutoGuideModel:
    """
    Create a guide model for ADVI with a mean-field normal distribution.

    A guide model is a variational distribution that approximates the posterior distribution of the model's free
    random variables. In this case, we use a mean-field normal distribution, which assumes that the free random
    variables are independent and normally distributed. For details, see _[1].

    For each free random variable in the model, we create a corresponding random variable in the guide model with a
    normal distribution. The mean and standard deviation of each normal distribution are parameterized by learnable
    parameters (loc and scale), which are initialized to small random values.

    Parameters
    ----------
    model : Model
        The probabilistic model for which to create the guide.

    Returns
    -------
    guide_model : AutoGuideModel
        An AutoGuideModel containing the guide model and the initial values for its parameters.

    References
    ----------
    .. [1] Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, and David M. Blei. Automatic Differentiation
    Variational Inference. Journal of Machine Learning Research, 18(14):1–45, 2017.
    """
    coords = model.coords
    free_rvs = model.free_RVs

    # Resolve each free RV's shape as a symbolic graph that does not depend on
    # the RVs themselves (see get_symbolic_rv_shapes).
    free_rv_shapes = dict(zip(free_rvs, get_symbolic_rv_shapes(free_rvs)))
    params_init_values = {}

    with Model(coords=coords) as guide_model:
        for rv in free_rvs:
            # One pair of variational parameters per free RV. `scale` is
            # unconstrained: it is mapped through softplus below, so its init
            # of 0.1 is a pre-softplus value, not the standard deviation itself.
            loc = pt.tensor(f"{rv.name}_loc", shape=rv.type.shape)
            scale = pt.tensor(f"{rv.name}_scale", shape=rv.type.shape)
            # TODO: Make these customizable
            # NOTE(review): loc init draws from an unseeded RNG, so repeated
            # calls yield different starting points — confirm this is intended.
            params_init_values[loc] = pt.random.uniform(-1, 1, size=free_rv_shapes[rv]).eval()
            params_init_values[scale] = pt.full(free_rv_shapes[rv], 0.1).eval()

            # Standard-normal base noise; the guide draw for `rv` is the affine
            # reparameterization loc + softplus(scale) * z, registered under
            # the original RV's name (and dims, when the model declares them).
            z = Normal(
                f"{rv.name}_z",
                mu=0,
                sigma=1,
                shape=free_rv_shapes[rv],
            )
            Deterministic(
                rv.name,
                loc + pt.softplus(scale) * z,
                dims=model.named_vars_to_dims.get(rv.name, None),
            )

    return AutoGuideModel(guide_model, params_init_values)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| from pymc import Model | ||
| from pytensor import graph_replace | ||
| from pytensor.tensor import TensorVariable | ||
|
|
||
| from pymc_extras.inference.advi.autoguide import AutoGuideModel | ||
|
|
||
|
|
||
def get_logp_logq(model: Model, guide: AutoGuideModel, stick_the_landing: bool = True):
    """
    Compute the log probability of the model and the guide.

    Parameters
    ----------
    model : Model
        The probabilistic model.
    guide : AutoGuideModel
        The variational guide.
    stick_the_landing : bool, optional
        Whether to use the stick-the-landing (STL) gradient estimator, by default True.
        The STL estimator has lower gradient variance by removing the score function term
        from the gradient. When True, gradients are stopped from flowing through logq.

    Returns
    -------
    logp : TensorVariable
        Log probability of the model.
    logq : TensorVariable
        Log probability of the guide.
    """
    # Substitute each unobserved RV's value variable with the guide variable of
    # the same name, so the model logp is evaluated at guide draws.
    replacements = {}
    for rv, value_var in model.rvs_to_values.items():
        if rv in model.observed_RVs:
            continue
        replacements[value_var] = guide.model[rv.name]

    logp = graph_replace(model.logp(), replacements)
    logq = guide.stochastic_logq(stick_the_landing=stick_the_landing)

    return logp, logq
|
|
||
|
|
||
def advi_objective(logp: TensorVariable, logq: TensorVariable):
    """Compute the negative ELBO objective for ADVI.

    The ELBO is ``E[logp - logq]``; minimizing ``logq - logp`` therefore
    maximizes the evidence lower bound.

    Parameters
    ----------
    logp : TensorVariable
        Log probability of the model.
    logq : TensorVariable
        Log probability of the guide.

    Returns
    -------
    TensorVariable
        The negative ELBO.
    """
    return logq - logp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Sequence | ||
| from typing import cast | ||
|
|
||
| from pymc import SymbolicRandomVariable | ||
| from pymc.distributions.shape_utils import change_dist_size | ||
| from pytensor import config | ||
| from pytensor import tensor as pt | ||
| from pytensor.graph import FunctionGraph, ancestors, vectorize_graph | ||
| from pytensor.tensor import TensorLike, TensorVariable | ||
| from pytensor.tensor.basic import infer_shape_db | ||
| from pytensor.tensor.random.op import RandomVariable | ||
| from pytensor.tensor.rewriting.shape import ShapeFeature | ||
|
|
||
|
|
||
def vectorize_random_graph(
    graph: Sequence[TensorVariable], batch_draws: TensorLike
) -> list[TensorVariable]:
    """Batch a random graph by expanding its root random variables.

    Every root random node (one that does not consume another random node)
    gets a new leading dimension of size ``batch_draws``; vectorizing the
    graph then propagates that dimension to the outputs.
    """
    # Collect all random nodes among the ancestors of the requested outputs.
    random_nodes = [
        node
        for node in ancestors(graph)
        if node.owner is not None
        and isinstance(node.owner.op, RandomVariable | SymbolicRandomVariable)
    ]
    random_node_set = set(random_nodes)
    # A root is a random node whose inputs contain no other random node.
    roots = [rv for rv in random_nodes if random_node_set.isdisjoint(rv.owner.inputs)]

    # Expand each root's size, then rebuild the rest of the graph around the
    # batched roots.
    batch_draws = pt.as_tensor(batch_draws, dtype=int)
    replacements = {
        root: change_dist_size(root, new_size=batch_draws, expand=True) for root in roots
    }
    return cast(list[TensorVariable], vectorize_graph(graph, replace=replacements))
|
|
||
|
|
||
def get_symbolic_rv_shapes(
    rvs: Sequence[TensorVariable], raise_if_rvs_in_graph: bool = True
) -> tuple[TensorVariable, ...]:
    """Return symbolic shape graphs for ``rvs``, simplified by shape inference.

    Parameters
    ----------
    rvs : Sequence[TensorVariable]
        Random variables whose shapes are requested.
    raise_if_rvs_in_graph : bool, default True
        If True, raise a ``ValueError`` when any returned shape graph still
        depends on one of ``rvs`` after the shape-inference rewrites.

    Returns
    -------
    tuple[TensorVariable, ...]
        One symbolic shape graph per input RV, in the same order.

    Raises
    ------
    ValueError
        If ``raise_if_rvs_in_graph`` is True and a shape graph still depends
        on one of the input RVs.
    """
    # TODO: Move me to pymc.pytensorf, this is needed often

    rv_shapes = [rv.shape for rv in rvs]
    # Run only the shape-inference rewrite database over a cloned graph;
    # cxx="" avoids touching the C backend for this purely symbolic pass.
    shape_fg = FunctionGraph(outputs=rv_shapes, features=[ShapeFeature()], clone=True)
    with config.change_flags(optdb__max_use_ratio=10, cxx=""):
        infer_shape_db.default_query.rewrite(shape_fg)
    rv_shapes = shape_fg.outputs

    if raise_if_rvs_in_graph and (overlap := (set(rvs) & set(ancestors(rv_shapes)))):
        raise ValueError(f"rv_shapes still depend on the following rvs {overlap}")

    return cast(tuple[TensorVariable, ...], tuple(rv_shapes))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| from typing import Protocol | ||
|
|
||
| import numpy as np | ||
|
|
||
| from pymc import Model, compile | ||
| from pymc.pytensorf import rewrite_pregrad | ||
| from pytensor import tensor as pt | ||
|
|
||
| from pymc_extras.inference.advi.autoguide import AutoGuideModel | ||
| from pymc_extras.inference.advi.objective import advi_objective, get_logp_logq | ||
| from pymc_extras.inference.advi.pytensorf import vectorize_random_graph | ||
|
|
||
|
|
||
class TrainingFn(Protocol):
    """Signature of the compiled SVI training function.

    Called with the number of Monte Carlo draws followed by one array per
    variational parameter; returns the loss followed by one gradient array
    per parameter (see ``compile_svi_training_fn``).
    """

    def __call__(self, draws: int, *params: np.ndarray) -> tuple[np.ndarray, ...]: ...
|
|
||
|
|
||
def compile_svi_training_fn(
    model: Model, guide: AutoGuideModel, stick_the_landing: bool = True, **compile_kwargs
) -> TrainingFn:
    """Compile a function returning the mean negative ELBO and its gradients.

    The returned function takes the number of Monte Carlo draws plus the
    current values of the guide's variational parameters, and outputs the
    averaged negative ELBO followed by its gradient with respect to each
    parameter.
    """
    n_draws = pt.scalar("draws", dtype=int)
    variational_params = guide.params

    logp, logq = get_logp_logq(model, guide, stick_the_landing=stick_the_landing)

    # Batch the single-draw objective over `n_draws` samples, then average.
    single_draw_loss = advi_objective(logp, logq)
    [batched_loss] = vectorize_random_graph([single_draw_loss], batch_draws=n_draws)
    mean_loss = batched_loss.mean(axis=0)

    loss_grads = pt.grad(rewrite_pregrad(mean_loss), wrt=variational_params)

    # Skip per-call input validation unless the caller explicitly opts out.
    compile_kwargs.setdefault("trust_input", True)

    return compile(
        inputs=[n_draws, *variational_params],
        outputs=[mean_loss, *loss_grads],
        **compile_kwargs,
    )
Empty file.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.