28 changes: 28 additions & 0 deletions pymc/dims/distributions/transforms.py
@@ -54,6 +54,34 @@ def log_jac_det(self, value, *inputs):
log_odds_transform = LogOddsTransform()


class SimplexTransform(DimTransform):
name = "simplex"

def __init__(self, dim: str):
self.core_dim = dim

def forward(self, value, *inputs):
log_value = ptx.math.log(value)
N = value.sizes[self.core_dim].astype(value.dtype)
shift = log_value.sum(self.core_dim) / N
return log_value.isel({self.core_dim: slice(None, -1)}) - shift

def backward(self, value, *inputs):
value = ptx.concat([value, -value.sum(self.core_dim)], dim=self.core_dim)
exp_value_max = ptx.math.exp(value - value.max(self.core_dim))
return exp_value_max / exp_value_max.sum(self.core_dim)

def log_jac_det(self, value, *inputs):
N = value.sizes[self.core_dim] + 1
N = N.astype(value.dtype)
sum_value = value.sum(self.core_dim)
value_sum_expanded = value + sum_value
value_sum_expanded = ptx.concat([value_sum_expanded, 0], dim=self.core_dim)
logsumexp_value_expanded = ptx.math.logsumexp(value_sum_expanded, dim=self.core_dim)
res = ptx.math.log(N) + (N * sum_value) - (N * logsumexp_value_expanded)
return res


class ZeroSumTransform(DimTransform):
name = "zerosum"

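For intuition, `forward` above is a centered-log-ratio map: it centers the element-wise log of a length-N simplex point and drops the last coordinate, leaving N-1 unconstrained values; `backward` restores the dropped coordinate (the full vector sums to zero by construction) and applies a numerically stabilized softmax. Below is a minimal NumPy sketch of that math, with a finite-difference check of the `log_jac_det` expression; the standalone helpers are illustrations of the formulas, not the PyMC API:

```python
import numpy as np

def forward(x):
    # center the log and drop the last coordinate: N simplex values -> N-1 free values
    log_x = np.log(x)
    shift = log_x.sum() / x.shape[-1]
    return (log_x - shift)[:-1]

def backward(y):
    # the full unconstrained vector sums to zero, so the dropped entry is -sum(y)
    y_full = np.concatenate([y, [-y.sum()]])
    e = np.exp(y_full - y_full.max())  # max-shift for numerical stability
    return e / e.sum()

def log_jac_det(y):
    # mirrors the expression in SimplexTransform.log_jac_det above
    N = y.shape[-1] + 1
    s = y.sum()
    expanded = np.concatenate([y + s, [0.0]])
    m = expanded.max()
    lse = m + np.log(np.exp(expanded - m).sum())
    return np.log(N) + N * s - N * lse

x = np.array([0.2, 0.3, 0.5])
y = forward(x)
assert np.allclose(backward(y), x)  # round-trip recovers the simplex point

# finite-difference Jacobian of the first N-1 simplex coordinates w.r.t. y
eps = 1e-6
J = np.empty((y.size, y.size))
for j in range(y.size):
    step = np.zeros(y.size)
    step[j] = eps
    J[:, j] = (backward(y + step)[:-1] - backward(y - step)[:-1]) / (2 * eps)
assert np.isclose(np.linalg.slogdet(J)[1], log_jac_det(y), atol=1e-6)
```

The closed form can be derived from det(J) = N * prod(x) for this zero-sum-softmax parametrization, which is why the expression needs only a single logsumexp over the expanded vector.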
57 changes: 56 additions & 1 deletion pymc/dims/distributions/vector.py
@@ -19,7 +19,7 @@
from pytensor.xtensor import random as pxr

from pymc.dims.distributions.core import VectorDimDistribution
from pymc.dims.distributions.transforms import ZeroSumTransform
from pymc.dims.distributions.transforms import SimplexTransform, ZeroSumTransform
from pymc.distributions.multivariate import ZeroSumNormalRV
from pymc.util import UNSET

@@ -63,6 +63,61 @@ def dist(cls, p=None, *, logit_p=None, core_dims=None, **kwargs):
return super().dist([p], core_dims=core_dims, **kwargs)


class Dirichlet(VectorDimDistribution):
"""Dirichlet distribution.

Parameters
----------
a : xtensor_like
Concentration parameters of the distribution. Must be positive along the core dimension.
core_dims : str
The core dimension of the distribution, which represents the categories.
The dimension must be present in `a`.
**kwargs
Other keyword arguments used to define the distribution.

Returns
-------
XTensorVariable
An xtensor variable representing the Dirichlet distribution.
The output keeps the core dimension; draws sum to 1 along it.
"""

xrv_op = ptxr.dirichlet

@classmethod
def __new__(
cls, *args, core_dims=None, dims=None, default_transform=UNSET, observed=None, **kwargs
):
if core_dims is not None:
if isinstance(core_dims, tuple | list):
[core_dims] = core_dims

# Create default_transform
if observed is None and default_transform is UNSET:
default_transform = SimplexTransform(dim=core_dims)

# If the user didn't specify dims, take it from core_dims
# We need them to be forwarded to dist in the `dim_lengths` argument
# if dims is None and core_dims is not None:
# dims = (..., *core_dims)

return super().__new__(
*args,
core_dims=core_dims,
dims=dims,
default_transform=default_transform,
observed=observed,
**kwargs,
)

@classmethod
def dist(cls, a, *, core_dims=None, **kwargs):
return super().dist([a], core_dims=core_dims, **kwargs)


class MvNormal(VectorDimDistribution):
"""Multivariate Normal distribution.

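Unlike `Categorical` above, where the core dimension is reduced away, a `Dirichlet` draw keeps its core dimension and sums to 1 along it. A minimal usage sketch (the `group`/`category` coords are made up for illustration, and the `as_xtensor` import path is assumed from these tests):

```python
import numpy as np
import pymc as pm

from pytensor.xtensor import as_xtensor  # assumed import path

from pymc.dims import Dirichlet

coords = {"group": range(3), "category": range(4)}
# flat concentration vector over the four categories
alpha = as_xtensor(np.ones(4), dims=("category",))

with pm.Model(coords=coords) as model:
    # one simplex per group; "category" is the core dim and is kept in the draw
    p = Dirichlet("p", a=alpha, core_dims="category", dims=("group", "category"))
```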
23 changes: 22 additions & 1 deletion tests/dims/distributions/test_vector.py
@@ -19,7 +19,7 @@
import pymc.distributions as regular_distributions

from pymc import Model
from pymc.dims import Categorical, MvNormal, ZeroSumNormal
from pymc.dims import Categorical, Dirichlet, MvNormal, ZeroSumNormal
from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph


@@ -40,6 +40,27 @@ def test_categorical():
assert_equivalent_logp_graph(model, reference_model)


def test_dirichlet():
coords = {"a": range(3), "b": range(2)}
alpha = pt.as_tensor([1, 2])  # length must match the "b" coord

alpha_xr = as_xtensor(alpha, dims=("b",))

with Model(coords=coords) as model:
Dirichlet("x", a=alpha_xr, core_dims="b", dims=("a", "b"))

with Model(coords=coords) as reference_model:
regular_distributions.Dirichlet("x", a=alpha, dims=("a", "b"))

assert_equivalent_random_graph(model, reference_model)

# The logp graphs differ structurally, but they evaluate to the same value
np.testing.assert_allclose(
model.compile_logp()(model.initial_point()),
reference_model.compile_logp()(reference_model.initial_point()),
)


def test_mvnormal():
coords = {"a": range(3), "b": range(2)}
mu = pt.as_tensor([1, 2])