|
| 1 | +from dataclasses import dataclass, field |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import pytensor.tensor as pt |
| 5 | + |
| 6 | +from pymc.distributions import Normal |
| 7 | +from pymc.logprob.basic import conditional_logp |
| 8 | +from pymc.model.core import Deterministic, Model |
| 9 | +from pytensor import graph_replace |
| 10 | +from pytensor.gradient import disconnected_grad |
| 11 | +from pytensor.graph.basic import Variable |
| 12 | + |
| 13 | +from pymc_extras.inference.advi.pytensorf import get_symbolic_rv_shapes |
| 14 | + |
| 15 | + |
| 16 | +@dataclass(frozen=True) |
| 17 | +class AutoGuideModel: |
| 18 | + model: Model |
| 19 | + params_init_values: dict[Variable, np.ndarray] |
| 20 | + name_to_param: dict[str, Variable] = field(init=False) |
| 21 | + |
| 22 | + def __post_init__(self): |
| 23 | + object.__setattr__( |
| 24 | + self, |
| 25 | + "name_to_param", |
| 26 | + {x.name: x for x in self.params_init_values.keys()}, |
| 27 | + ) |
| 28 | + |
| 29 | + @property |
| 30 | + def params(self) -> tuple[Variable, ...]: |
| 31 | + return tuple(self.params_init_values.keys()) |
| 32 | + |
| 33 | + def __getitem__(self, name: str) -> Variable: |
| 34 | + return self.name_to_param[name] |
| 35 | + |
| 36 | + def stochastic_logq(self, stick_the_landing: bool = True) -> pt.TensorVariable: |
| 37 | + """Returns a graph representing the logp of the guide model, evaluated under draws from its random variables.""" |
| 38 | + logp_terms = conditional_logp( |
| 39 | + {rv: rv for rv in self.model.deterministics}, |
| 40 | + warn_rvs=False, |
| 41 | + ) |
| 42 | + logq = pt.sum([logp_term.sum() for logp_term in logp_terms.values()]) |
| 43 | + |
| 44 | + if stick_the_landing: |
| 45 | + # Detach variational parameters from the gradient computation of logq |
| 46 | + repl = {p: disconnected_grad(p) for p in self.params} |
| 47 | + logq = graph_replace(logq, repl) |
| 48 | + |
| 49 | + return logq |
| 50 | + |
| 51 | + |
| 52 | +def AutoDiagonalNormal(model: Model) -> AutoGuideModel: |
| 53 | + """ |
| 54 | + Create a guide model for ADVI with a mean-field normal distribution. |
| 55 | +
|
| 56 | + A guide model is a variational distribution that approximates the posterior distribution of the model's free |
| 57 | + random variables. In this case, we use a mean-field normal distribution, which assumes that the free random |
| 58 | + variables are independent and normally distributed. For details, see _[1]. |
| 59 | +
|
| 60 | + For each free random variable in the model, we create a corresponding random variable in the guide model with a |
| 61 | + normal distribution. The mean and standard deviation of each normal distribution are parameterized by learnable |
| 62 | + parameters (loc and scale), which are initialized to small random values. |
| 63 | +
|
| 64 | + Parameters |
| 65 | + ---------- |
| 66 | + model : Model |
| 67 | + The probabilistic model for which to create the guide. |
| 68 | +
|
| 69 | + Returns |
| 70 | + ------- |
| 71 | + guide_model : AutoGuideModel |
| 72 | + An AutoGuideModel containing the guide model and the initial values for its parameters. |
| 73 | +
|
| 74 | + References |
| 75 | + ---------- |
| 76 | + .. [1] Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, and David M. Blei. Automatic Differentiation |
| 77 | + Variational Inference. Journal of Machine Learning Research, 18(14):1–45, 2017. |
| 78 | + """ |
| 79 | + coords = model.coords |
| 80 | + free_rvs = model.free_RVs |
| 81 | + |
| 82 | + free_rv_shapes = dict(zip(free_rvs, get_symbolic_rv_shapes(free_rvs))) |
| 83 | + params_init_values = {} |
| 84 | + |
| 85 | + with Model(coords=coords) as guide_model: |
| 86 | + for rv in free_rvs: |
| 87 | + loc = pt.tensor(f"{rv.name}_loc", shape=rv.type.shape) |
| 88 | + scale = pt.tensor(f"{rv.name}_scale", shape=rv.type.shape) |
| 89 | + # TODO: Make these customizable |
| 90 | + params_init_values[loc] = pt.random.uniform(-1, 1, size=free_rv_shapes[rv]).eval() |
| 91 | + params_init_values[scale] = pt.full(free_rv_shapes[rv], 0.1).eval() |
| 92 | + |
| 93 | + z = Normal( |
| 94 | + f"{rv.name}_z", |
| 95 | + mu=0, |
| 96 | + sigma=1, |
| 97 | + shape=free_rv_shapes[rv], |
| 98 | + ) |
| 99 | + Deterministic( |
| 100 | + rv.name, |
| 101 | + loc + pt.softplus(scale) * z, |
| 102 | + dims=model.named_vars_to_dims.get(rv.name, None), |
| 103 | + ) |
| 104 | + |
| 105 | + return AutoGuideModel(guide_model, params_init_values) |
0 commit comments