Skip to content

Commit bcb23cc

Browse files
committed
Merge upstream/main to incorporate latest fixes
2 parents 4dad014 + 2af4eb0 commit bcb23cc

File tree

6 files changed

+117
-12
lines changed

6 files changed

+117
-12
lines changed

pymc/dims/distributions/transforms.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,34 @@ def log_jac_det(self, value, *inputs):
5454
log_odds_transform = LogOddsTransform()
5555

5656

57+
class SimplexTransform(DimTransform):
58+
name = "simplex"
59+
60+
def __init__(self, dim: str):
61+
self.core_dim = dim
62+
63+
def forward(self, value, *inputs):
64+
log_value = ptx.math.log(value)
65+
N = value.sizes[self.core_dim].astype(value.dtype)
66+
shift = log_value.sum(self.core_dim) / N
67+
return log_value.isel({self.core_dim: slice(None, -1)}) - shift
68+
69+
def backward(self, value, *inputs):
70+
value = ptx.concat([value, -value.sum(self.core_dim)], dim=self.core_dim)
71+
exp_value_max = ptx.math.exp(value - value.max(self.core_dim))
72+
return exp_value_max / exp_value_max.sum(self.core_dim)
73+
74+
def log_jac_det(self, value, *inputs):
75+
N = value.sizes[self.core_dim] + 1
76+
N = N.astype(value.dtype)
77+
sum_value = value.sum(self.core_dim)
78+
value_sum_expanded = value + sum_value
79+
value_sum_expanded = ptx.concat([value_sum_expanded, 0], dim=self.core_dim)
80+
logsumexp_value_expanded = ptx.math.logsumexp(value_sum_expanded, dim=self.core_dim)
81+
res = ptx.math.log(N) + (N * sum_value) - (N * logsumexp_value_expanded)
82+
return res
83+
84+
5785
class ZeroSumTransform(DimTransform):
5886
name = "zerosum"
5987

pymc/dims/distributions/vector.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pytensor.xtensor import random as pxr
2020

2121
from pymc.dims.distributions.core import VectorDimDistribution
22-
from pymc.dims.distributions.transforms import ZeroSumTransform
22+
from pymc.dims.distributions.transforms import SimplexTransform, ZeroSumTransform
2323
from pymc.distributions.multivariate import ZeroSumNormalRV
2424
from pymc.util import UNSET
2525

@@ -63,6 +63,61 @@ def dist(cls, p=None, *, logit_p=None, core_dims=None, **kwargs):
6363
return super().dist([p], core_dims=core_dims, **kwargs)
6464

6565

66+
class Dirichlet(VectorDimDistribution):
67+
"""Dirichlet distribution.
68+
69+
Parameters
70+
----------
71+
a : xtensor_like, optional
72+
Probabilities of each category. Must sum to 1 along the core dimension.
73+
core_dims : str
74+
The core dimension of the distribution, which represents the categories.
75+
The dimension must be present in `p` or `logit_p`.
76+
**kwargs
77+
Other keyword arguments used to define the distribution.
78+
79+
Returns
80+
-------
81+
XTensorVariable
82+
An xtensor variable representing the categorical distribution.
83+
The output does not contain the core dimension, as it is absorbed into the distribution.
84+
85+
86+
"""
87+
88+
xrv_op = ptxr.dirichlet
89+
90+
@classmethod
91+
def __new__(
92+
cls, *args, core_dims=None, dims=None, default_transform=UNSET, observed=None, **kwargs
93+
):
94+
if core_dims is not None:
95+
if isinstance(core_dims, tuple | list):
96+
[core_dims] = core_dims
97+
98+
# Create default_transform
99+
if observed is None and default_transform is UNSET:
100+
default_transform = SimplexTransform(dim=core_dims)
101+
102+
# If the user didn't specify dims, take it from core_dims
103+
# We need them to be forwarded to dist in the `dim_lenghts` argument
104+
# if dims is None and core_dims is not None:
105+
# dims = (..., *core_dims)
106+
107+
return super().__new__(
108+
*args,
109+
core_dims=core_dims,
110+
dims=dims,
111+
default_transform=default_transform,
112+
observed=observed,
113+
**kwargs,
114+
)
115+
116+
@classmethod
117+
def dist(cls, a, *, core_dims=None, **kwargs):
118+
return super().dist([a], core_dims=core_dims, **kwargs)
119+
120+
66121
class MvNormal(VectorDimDistribution):
67122
"""Multivariate Normal distribution.
68123

pymc/distributions/continuous.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2575,6 +2575,11 @@ class ChiSquared:
25752575
----------
25762576
nu : tensor_like of float
25772577
Degrees of freedom (nu > 0).
2578+
2579+
Notes
2580+
-----
2581+
This is implemented as a special case of the Gamma distribution.
2582+
:math:`\chi^2(\nu) = \text{Gamma}(\alpha=\nu/2, \beta=1/2)`
25782583
"""
25792584

25802585
def __new__(cls, name, nu, **kwargs):
@@ -3601,7 +3606,7 @@ def icdf(value, mu, s):
36013606
class LogitNormalRV(SymbolicRandomVariable):
36023607
name = "logit_normal"
36033608
extended_signature = "[rng],[size],(),()->[rng],()"
3604-
_print_name = ("logitNormal", "\\operatorname{logitNormal}")
3609+
_print_name = ("LogitNormal", "\\operatorname{LogitNormal}")
36053610

36063611
@classmethod
36073612
def rv_op(cls, mu, sigma, *, size=None, rng=None):

pymc/distributions/discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def logcdf(value, p):
393393
class DiscreteWeibullRV(SymbolicRandomVariable):
394394
name = "discrete_weibull"
395395
extended_signature = "[rng],[size],(),()->[rng],()"
396-
_print_name = ("dWeibull", "\\operatorname{dWeibull}")
396+
_print_name = ("DiscreteWeibull", "\\operatorname{DiscreteWeibull}")
397397

398398
@classmethod
399399
def rv_op(cls, q, beta, *, size=None, rng=None):

tests/dims/distributions/test_vector.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pymc.distributions as regular_distributions
2020

2121
from pymc import Model
22-
from pymc.dims import Categorical, MvNormal, ZeroSumNormal
22+
from pymc.dims import Categorical, Dirichlet, MvNormal, ZeroSumNormal
2323
from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph
2424

2525

@@ -40,6 +40,27 @@ def test_categorical():
4040
assert_equivalent_logp_graph(model, reference_model)
4141

4242

43+
def test_dirichlet():
44+
coords = {"a": range(3), "b": range(2)}
45+
alpha = pt.as_tensor([1, 2, 3])
46+
47+
alpha_xr = as_xtensor(alpha, dims=("b",))
48+
49+
with Model(coords=coords) as model:
50+
Dirichlet("x", a=alpha_xr, core_dims="b", dims=("a", "b"))
51+
52+
with Model(coords=coords) as reference_model:
53+
regular_distributions.Dirichlet("x", a=alpha, dims=("a", "b"))
54+
55+
assert_equivalent_random_graph(model, reference_model)
56+
57+
# logp graphs end up different, but they mean the same thing
58+
np.testing.assert_allclose(
59+
model.compile_logp()(model.initial_point()),
60+
reference_model.compile_logp()(reference_model.initial_point()),
61+
)
62+
63+
4364
def test_mvnormal():
4465
coords = {"a": range(3), "b": range(2)}
4566
mu = pt.as_tensor([1, 2])

tests/distributions/test_censored.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,13 @@ def test_censored_logcdf_continuous(self):
137137

138138
# No censoring
139139
censored_norm = pm.Censored.dist(norm, lower=None, upper=None)
140-
with pytest.warns(RuntimeWarning, match=match_str):
141-
censored_eval = logcdf(censored_norm, eval_points).eval()
140+
censored_eval = logcdf(censored_norm, eval_points).eval()
142141
np.testing.assert_allclose(censored_eval, expected_logcdf_uncensored)
143142

144143
# Left censoring
145144
censored_norm = pm.Censored.dist(norm, lower=-1, upper=None)
146145
expected_left = np.where(eval_points < -1, -np.inf, expected_logcdf_uncensored)
147-
with pytest.warns(RuntimeWarning, match=match_str):
148-
censored_eval = logcdf(censored_norm, eval_points).eval()
146+
censored_eval = logcdf(censored_norm, eval_points).eval()
149147
np.testing.assert_allclose(
150148
censored_eval,
151149
expected_left,
@@ -155,8 +153,7 @@ def test_censored_logcdf_continuous(self):
155153
# Right censoring
156154
censored_norm = pm.Censored.dist(norm, lower=None, upper=1)
157155
expected_right = np.where(eval_points >= 1, 0.0, expected_logcdf_uncensored)
158-
with pytest.warns(RuntimeWarning, match=match_str):
159-
censored_eval = logcdf(censored_norm, eval_points).eval()
156+
censored_eval = logcdf(censored_norm, eval_points).eval()
160157
np.testing.assert_allclose(
161158
censored_eval,
162159
expected_right,
@@ -167,8 +164,7 @@ def test_censored_logcdf_continuous(self):
167164
censored_norm = pm.Censored.dist(norm, lower=-1, upper=1)
168165
expected_interval = np.where(eval_points < -1, -np.inf, expected_logcdf_uncensored)
169166
expected_interval = np.where(eval_points >= 1, 0.0, expected_interval)
170-
with pytest.warns(RuntimeWarning, match=match_str):
171-
censored_eval = logcdf(censored_norm, eval_points).eval()
167+
censored_eval = logcdf(censored_norm, eval_points).eval()
172168
np.testing.assert_allclose(
173169
censored_eval,
174170
expected_interval,

0 commit comments

Comments
 (0)