diff --git a/pymc/dims/distributions/transforms.py b/pymc/dims/distributions/transforms.py
index 8f49d2a16f..3425fbe69e 100644
--- a/pymc/dims/distributions/transforms.py
+++ b/pymc/dims/distributions/transforms.py
@@ -54,6 +54,34 @@ def log_jac_det(self, value, *inputs):
 log_odds_transform = LogOddsTransform()
 
 
+class SimplexTransform(DimTransform):
+    name = "simplex"
+
+    def __init__(self, dim: str):
+        self.core_dim = dim
+
+    def forward(self, value, *inputs):
+        # Map from the simplex to unconstrained space: center the log values
+        # and drop the last (redundant) coordinate.
+        log_value = ptx.math.log(value)
+        N = value.sizes[self.core_dim].astype(value.dtype)
+        shift = log_value.sum(self.core_dim) / N
+        return log_value.isel({self.core_dim: slice(None, -1)}) - shift
+
+    def backward(self, value, *inputs):
+        # Restore the dropped coordinate (the full vector sums to zero) and
+        # softmax back onto the simplex.
+        value = ptx.concat([value, -value.sum(self.core_dim)], dim=self.core_dim)
+        exp_value_max = ptx.math.exp(value - value.max(self.core_dim))
+        return exp_value_max / exp_value_max.sum(self.core_dim)
+
+    def log_jac_det(self, value, *inputs):
+        # Log-determinant of the Jacobian of `backward`, mirroring the tensor
+        # version in pymc.distributions.transforms.SimplexTransform.
+        N = value.sizes[self.core_dim] + 1
+        N = N.astype(value.dtype)
+        sum_value = value.sum(self.core_dim)
+        value_sum_expanded = value + sum_value
+        value_sum_expanded = ptx.concat([value_sum_expanded, 0], dim=self.core_dim)
+        logsumexp_value_expanded = ptx.math.logsumexp(value_sum_expanded, dim=self.core_dim)
+        res = ptx.math.log(N) + (N * sum_value) - (N * logsumexp_value_expanded)
+        return res
+
+
 class ZeroSumTransform(DimTransform):
     name = "zerosum"
 
diff --git a/pymc/dims/distributions/vector.py b/pymc/dims/distributions/vector.py
index 0ad834c8a8..7107712e67 100644
--- a/pymc/dims/distributions/vector.py
+++ b/pymc/dims/distributions/vector.py
@@ -19,7 +19,7 @@
 from pytensor.xtensor import random as pxr
 
 from pymc.dims.distributions.core import VectorDimDistribution
-from pymc.dims.distributions.transforms import ZeroSumTransform
+from pymc.dims.distributions.transforms import SimplexTransform, ZeroSumTransform
 from pymc.distributions.multivariate import ZeroSumNormalRV
 from pymc.util import UNSET
 
@@ -63,6 +63,61 @@ def dist(cls, p=None, *, logit_p=None, core_dims=None, **kwargs):
         return super().dist([p], core_dims=core_dims, **kwargs)
 
 
+class Dirichlet(VectorDimDistribution):
+    """Dirichlet distribution.
+
+    Parameters
+    ----------
+    a : xtensor_like
+        Concentration parameters of the distribution. Must be positive.
+    core_dims : str
+        The core dimension of the distribution, which indexes the categories.
+        The dimension must be present in `a`.
+    **kwargs
+        Other keyword arguments used to define the distribution.
+
+    Returns
+    -------
+    XTensorVariable
+        An xtensor variable representing the Dirichlet distribution.
+        The output keeps the core dimension; entries sum to 1 along it.
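+
+    Examples
+    --------
+    A minimal usage sketch (the coords, names, and concentration values are
+    illustrative only):
+
+    .. code-block:: python
+
+        import pymc as pm
+
+        from pymc.dims import Dirichlet
+        from pytensor.xtensor import as_xtensor
+
+        coords = {"trial": range(5), "option": range(3)}
+        with pm.Model(coords=coords) as model:
+            conc = as_xtensor([1.0, 1.0, 1.0], dims=("option",))
+            p = Dirichlet("p", a=conc, core_dims="option", dims=("trial", "option"))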
+    """
+
+    xrv_op = pxr.dirichlet
+
+    def __new__(
+        cls, *args, core_dims=None, dims=None, default_transform=UNSET, observed=None, **kwargs
+    ):
+        if core_dims is not None:
+            if isinstance(core_dims, tuple | list):
+                [core_dims] = core_dims
+
+            # Default to the simplex transform for free (unobserved) variables
+            if observed is None and default_transform is UNSET:
+                default_transform = SimplexTransform(dim=core_dims)
+
+        # If the user didn't specify dims, take it from core_dims
+        # We need them to be forwarded to dist in the `dim_lengths` argument
+        # if dims is None and core_dims is not None:
+        #     dims = (..., *core_dims)
+
+        return super().__new__(
+            cls,
+            *args,
+            core_dims=core_dims,
+            dims=dims,
+            default_transform=default_transform,
+            observed=observed,
+            **kwargs,
+        )
+
+    @classmethod
+    def dist(cls, a, *, core_dims=None, **kwargs):
+        return super().dist([a], core_dims=core_dims, **kwargs)
+
+
 class MvNormal(VectorDimDistribution):
     """Multivariate Normal distribution.
 
diff --git a/tests/dims/distributions/test_vector.py b/tests/dims/distributions/test_vector.py
index 3a57453b48..8cfdadb372 100644
--- a/tests/dims/distributions/test_vector.py
+++ b/tests/dims/distributions/test_vector.py
@@ -19,7 +19,7 @@
 import pymc.distributions as regular_distributions
 
 from pymc import Model
-from pymc.dims import Categorical, MvNormal, ZeroSumNormal
+from pymc.dims import Categorical, Dirichlet, MvNormal, ZeroSumNormal
 from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph
 
@@ -40,6 +40,27 @@ def test_categorical():
     assert_equivalent_logp_graph(model, reference_model)
 
 
+def test_dirichlet():
+    coords = {"a": range(2), "b": range(3)}
+    alpha = pt.as_tensor([1, 2, 3])
+    alpha_xr = as_xtensor(alpha, dims=("b",))
+
+    with Model(coords=coords) as model:
+        Dirichlet("x", a=alpha_xr, core_dims="b", dims=("a", "b"))
+
+    with Model(coords=coords) as reference_model:
+        regular_distributions.Dirichlet("x", a=alpha, dims=("a", "b"))
+
+    assert_equivalent_random_graph(model, reference_model)
+
+    # The logp graphs differ structurally but evaluate to the same value
+    np.testing.assert_allclose(
+        model.compile_logp()(model.initial_point()),
+        reference_model.compile_logp()(reference_model.initial_point()),
+    )
+
+
 def test_mvnormal():
     coords = {"a": range(3), "b": range(2)}
     mu = pt.as_tensor([1, 2])
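
As a quick numerical sanity check of the SimplexTransform math above: the forward/backward pair should round-trip any point on the simplex. A standalone NumPy sketch, independent of the pytensor.xtensor API (all names here are illustrative, not part of the diff):

    import numpy as np


    def forward(x):
        # Center the log values and drop the last (redundant) coordinate
        log_x = np.log(x)
        shift = log_x.sum(axis=-1, keepdims=True) / x.shape[-1]
        return (log_x - shift)[..., :-1]


    def backward(y):
        # Restore the dropped coordinate and softmax back onto the simplex
        z = np.concatenate([y, -y.sum(axis=-1, keepdims=True)], axis=-1)
        e = np.exp(z - z.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)


    rng = np.random.default_rng(42)
    x = rng.dirichlet([1.0, 2.0, 3.0], size=5)
    np.testing.assert_allclose(backward(forward(x)), x)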