Description
This pertains to the logprob submodule. During logprob derivation of an expression like
```python
import numpy as np
import pymc as pm

x_raw = pm.Normal.dist(np.arange(5), shape=(2, 5))
x = pm.math.clip(x_raw, -1, 1)  # Censored normal
pm.logp(x, np.zeros((2, 5)))
```

we create a `MeasurableClip` that replaces `x` once we identify that its logprob can be derived as a simple censored density. However, this `MeasurableClip` does not retain any meta-information about the RV it encapsulates (`ndim_supp`, `dtype`, support axis).
pymc/pymc/logprob/censoring.py
Lines 61 to 67 in a0d6ba0

```python
class MeasurableClip(MeasurableElemwise):
    """A placeholder used to specify a log-likelihood for a clipped RV sub-graph."""

    valid_scalar_types = (Clip,)


measurable_clip = MeasurableClip(scalar_clip)
```
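To make the "simple censored density" concrete, here is a minimal pure-Python sketch of the log-probability that such a `MeasurableClip` stands for when the clipped RV is a normal. It assumes the usual censoring convention (log CDF at the lower bound, log survival function at the upper bound, normal log-pdf in between, `-inf` outside). The function names are hypothetical and this is not PyMC's implementation.

```python
import math


def normal_logpdf(x: float, mu: float, sigma: float) -> float:
    """Log-density of N(mu, sigma**2) at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def normal_logcdf(x: float, mu: float, sigma: float) -> float:
    """log P(X <= x), computed via erfc to avoid cancellation in 1 - erf."""
    return math.log(0.5 * math.erfc(-(x - mu) / (sigma * math.sqrt(2.0))))


def normal_logsf(x: float, mu: float, sigma: float) -> float:
    """log P(X >= x), computed via erfc to avoid cancellation in 1 - erf."""
    return math.log(0.5 * math.erfc((x - mu) / (sigma * math.sqrt(2.0))))


def censored_normal_logp(value, mu, sigma, lower, upper):
    """Log-probability of clip(X, lower, upper) with X ~ N(mu, sigma**2).

    All mass below `lower` piles up at `lower` and all mass above `upper`
    piles up at `upper`; values strictly in between keep the normal density.
    """
    if value < lower or value > upper:
        return -math.inf  # clip can never produce these values
    if value == lower:
        return normal_logcdf(lower, mu, sigma)  # point mass at the lower bound
    if value == upper:
        return normal_logsf(upper, mu, sigma)  # point mass at the upper bound
    return normal_logpdf(value, mu, sigma)  # interior: ordinary normal density


# Same setting as the first example: mu = 0..4, sigma = 1, bounds [-1, 1], value 0.
print([censored_normal_logp(0.0, mu, 1.0, -1.0, 1.0) for mu in range(5)])
```

In the first example every value `0` lies strictly inside `[-1, 1]`, so under this convention each element should reduce to the plain normal log-pdf; the bound terms only matter for values exactly at `-1` or `1`.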
If we want to compose the graph further, we run into problems as soon as an operation needs this information:
```python
import numpy as np
import pymc as pm

x_raw = pm.Normal.dist(np.arange(5), shape=(2, 5))
x = pm.math.clip(x_raw, -1, 1)  # Censored normal
x = x.T
pm.logp(x, np.zeros((5, 2)))  # NotImplementedError: PyMC could not infer logp of input variable.
```

This happens because, to infer the logprob of a transposed (dimshuffled) variable, we need to know the original support dimensionality and support axis (which is always the rightmost for pure distributions):
Lines 285 to 298 in a0d6ba0

```python
    # We can only apply this rewrite directly to `RandomVariable`s, as those are
    # the only `Op`s for which we always know the support axis. Other measurable
    # variables can have arbitrary support axes (e.g., if they contain separate
    # `MeasurableDimShuffle`s). Most measurable variables with `DimShuffle`s
    # should still be supported as long as the `DimShuffle`s can be merged/
    # lifted towards the base RandomVariable.
    # TODO: If we include the support axis as meta information in each
    # intermediate MeasurableVariable, we can lift this restriction.
    if not (
        base_var.owner
        and isinstance(base_var.owner.op, RandomVariable)
        and base_var not in rv_map_feature.rv_values
    ):
        return None  # pragma: no cover
```
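Why the support axis matters can be illustrated with a small, hypothetical check (not PyMC code). Given a transpose permutation and the number of support dimensions, it tests whether the support axes stay the rightmost ones, in their original order. With `ndim_supp=0`, as for the censored normal, every permutation passes, because the logp is elementwise and transposing the value simply transposes the logp.

```python
from typing import Sequence, Tuple


def support_axes(ndim: int, ndim_supp: int) -> Tuple[int, ...]:
    """Support axes of a pure RandomVariable: always the rightmost `ndim_supp` axes."""
    return tuple(range(ndim - ndim_supp, ndim))


def transpose_keeps_support(perm: Sequence[int], ndim_supp: int) -> bool:
    """Hypothetical check: does transposing with `perm` keep the support axes
    rightmost and in their original order?

    Output axis i of the transpose comes from input axis perm[i], so the last
    `ndim_supp` entries of `perm` must be exactly the input support axes.
    """
    ndim = len(perm)
    if sorted(perm) != list(range(ndim)):
        raise ValueError(f"{perm!r} is not a permutation of range({ndim})")
    if not 0 <= ndim_supp <= ndim:
        raise ValueError("ndim_supp must lie between 0 and the number of axes")
    return tuple(perm[ndim - ndim_supp:]) == support_axes(ndim, ndim_supp)


print(transpose_keeps_support((1, 0), ndim_supp=0))  # x.T of a univariate RV
print(transpose_keeps_support((1, 0), ndim_supp=1))  # would move a vector support axis
```

The first call corresponds to the `x.T` in the second example and returns `True`; the second returns `False`, which is the case the rewrite has to rule out and why it currently refuses anything that is not a `RandomVariable`.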
If we propagated that information to the `MeasurableClip` (`ndim_supp=0`, `support_axis=None`, `dtype="mixed"`), the `DimShuffle` rewrite could be applied safely and we could derive the logp for the second example. This information would also be useful for other rewrites.
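A minimal sketch of what propagating that information could look like, using a hypothetical `MeasurableMeta` container and `clip_meta` helper (neither is a PyMC class or function) whose fields mirror the ones named above. The assumption is that clip acts elementwise, so the support dimensionality and axis are copied from the base RV, while clipping a continuous RV adds point masses at the bounds and is therefore labelled `"mixed"`.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class MeasurableMeta:
    """Hypothetical meta-information a measurable variable could carry."""

    ndim_supp: int
    support_axis: Optional[Tuple[int, ...]]  # None when ndim_supp == 0
    dtype: str  # "continuous", "discrete" or "mixed", as in the issue text


def clip_meta(base: MeasurableMeta) -> MeasurableMeta:
    """Meta-information for clip(base, lower, upper) (hypothetical helper)."""
    # Clip is elementwise, so the support structure of the base RV is unchanged.
    # Censoring a continuous RV adds point masses at the bounds -> "mixed";
    # clipping a discrete or already-mixed RV keeps its type.
    new_dtype = "mixed" if base.dtype == "continuous" else base.dtype
    return replace(base, dtype=new_dtype)


normal_meta = MeasurableMeta(ndim_supp=0, support_axis=None, dtype="continuous")
clipped_meta = clip_meta(normal_meta)
print(clipped_meta)  # MeasurableMeta(ndim_supp=0, support_axis=None, dtype='mixed')
```

Using a frozen dataclass with `replace` keeps the base RV's meta-information untouched, so each rewrite derives new meta-information instead of mutating shared state.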
More context in aesara-devs/aeppl#183