Skip to content

Commit 9e90fae

Browse files
committed
Adding Dirichlet and SimplexTransform to pymc.dims
1 parent 340e403 commit 9e90fae

File tree

3 files changed

+106
-2
lines changed

3 files changed

+106
-2
lines changed

pymc/dims/distributions/transforms.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,34 @@ def log_jac_det(self, value, *inputs):
5454
log_odds_transform = LogOddsTransform()
5555

5656

57+
class SimplexTransform(DimTransform):
58+
name = "simplex"
59+
60+
def __init__(self, dim: str):
61+
self.core_dim = dim
62+
63+
def forward(self, value, *inputs):
64+
log_value = ptx.math.log(value)
65+
N = value.sizes[self.core_dim].astype(value.dtype)
66+
shift = log_value.sum(self.core_dim) / N
67+
return log_value.isel({self.core_dim: slice(None, -1)}) - shift
68+
69+
def backward(self, value, *inputs):
70+
value = ptx.concat([value, -value.sum(self.core_dim)], dim=self.core_dim)
71+
exp_value_max = ptx.math.exp(value - value.max(self.core_dim))
72+
return exp_value_max / exp_value_max.sum(self.core_dim)
73+
74+
def log_jac_det(self, value, *inputs):
75+
N = value.sizes[self.core_dim] + 1
76+
N = N.astype(value.dtype)
77+
sum_value = value.sum(self.core_dim)
78+
value_sum_expanded = value + sum_value
79+
value_sum_expanded = ptx.concat([value_sum_expanded, 0], dim=self.core_dim)
80+
logsumexp_value_expanded = ptx.math.logsumexp(value_sum_expanded, dim=self.core_dim)
81+
res = ptx.math.log(N) + (N * sum_value) - (N * logsumexp_value_expanded)
82+
return res
83+
84+
5785
class ZeroSumTransform(DimTransform):
5886
name = "zerosum"
5987

pymc/dims/distributions/vector.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pytensor.xtensor import random as pxr
2020

2121
from pymc.dims.distributions.core import VectorDimDistribution
22-
from pymc.dims.distributions.transforms import ZeroSumTransform
22+
from pymc.dims.distributions.transforms import SimplexTransform, ZeroSumTransform
2323
from pymc.distributions.multivariate import ZeroSumNormalRV
2424
from pymc.util import UNSET
2525

@@ -63,6 +63,61 @@ def dist(cls, p=None, *, logit_p=None, core_dims=None, **kwargs):
6363
return super().dist([p], core_dims=core_dims, **kwargs)
6464

6565

66+
class Dirichlet(VectorDimDistribution):
67+
"""Dirichlet distribution.
68+
69+
Parameters
70+
----------
71+
a : xtensor_like, optional
72+
Probabilities of each category. Must sum to 1 along the core dimension.
73+
core_dims : str
74+
The core dimension of the distribution, which represents the categories.
75+
The dimension must be present in `p` or `logit_p`.
76+
**kwargs
77+
Other keyword arguments used to define the distribution.
78+
79+
Returns
80+
-------
81+
XTensorVariable
82+
An xtensor variable representing the categorical distribution.
83+
The output does not contain the core dimension, as it is absorbed into the distribution.
84+
85+
86+
"""
87+
88+
xrv_op = ptxr.dirichlet
89+
90+
@classmethod
91+
def __new__(
92+
cls, *args, core_dims=None, dims=None, default_transform=UNSET, observed=None, **kwargs
93+
):
94+
if core_dims is not None:
95+
if isinstance(core_dims, tuple | list):
96+
[core_dims] = core_dims
97+
98+
# Create default_transform
99+
if observed is None and default_transform is UNSET:
100+
default_transform = SimplexTransform(dim=core_dims)
101+
102+
# If the user didn't specify dims, take it from core_dims
103+
# We need them to be forwarded to dist in the `dim_lenghts` argument
104+
# if dims is None and core_dims is not None:
105+
# dims = (..., *core_dims)
106+
107+
return super().__new__(
108+
*args,
109+
core_dims=core_dims,
110+
dims=dims,
111+
default_transform=default_transform,
112+
observed=observed,
113+
**kwargs,
114+
)
115+
116+
@classmethod
117+
def dist(cls, a, *, core_dims=None, **kwargs):
118+
return super().dist([a], core_dims=core_dims, **kwargs)
119+
120+
66121
class MvNormal(VectorDimDistribution):
67122
"""Multivariate Normal distribution.
68123

tests/dims/distributions/test_vector.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pymc.distributions as regular_distributions
2020

2121
from pymc import Model
22-
from pymc.dims import Categorical, MvNormal, ZeroSumNormal
22+
from pymc.dims import Categorical, Dirichlet, MvNormal, ZeroSumNormal
2323
from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph
2424

2525

@@ -40,6 +40,27 @@ def test_categorical():
4040
assert_equivalent_logp_graph(model, reference_model)
4141

4242

43+
def test_dirichlet():
44+
coords = {"a": range(3), "b": range(2)}
45+
alpha = pt.as_tensor([1, 2, 3])
46+
47+
alpha_xr = as_xtensor(alpha, dims=("b",))
48+
49+
with Model(coords=coords) as model:
50+
Dirichlet("x", a=alpha_xr, core_dims="b", dims=("a", "b"))
51+
52+
with Model(coords=coords) as reference_model:
53+
regular_distributions.Dirichlet("x", a=alpha, dims=("a", "b"))
54+
55+
assert_equivalent_random_graph(model, reference_model)
56+
57+
# logp graphs end up different, but they mean the same thing
58+
np.testing.assert_allclose(
59+
model.compile_logp()(model.initial_point()),
60+
reference_model.compile_logp()(reference_model.initial_point()),
61+
)
62+
63+
4364
def test_mvnormal():
4465
coords = {"a": range(3), "b": range(2)}
4566
mu = pt.as_tensor([1, 2])

0 commit comments

Comments
 (0)