Skip to content

Commit 52682eb

Browse files
authored
Parametrize Binomial and Categorical distributions via logit_p (#5637)
1 parent 5b31ec7 commit 52682eb

File tree

3 files changed

+70
-6
lines changed

3 files changed

+70
-6
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ This includes API changes we did not warn about since at least `3.11.0` (2021-01
133133
- Adding support for blackjax's NUTS sampler `pymc.sampling_jax` (see [#5477](ihttps://github.com/pymc-devs/pymc/pull/5477))
134134
- `pymc.sampling_jax` samplers support `log_likelihood`, `observed_data`, and `sample_stats` in returned InferenceData object (see [#5189](https://github.com/pymc-devs/pymc/pull/5189))
135135
- Adding support for `pm.Deterministic` in `pymc.sampling_jax` (see [#5182](https://github.com/pymc-devs/pymc/pull/5182))
136+
- Added an alternative parametrization, `logit_p` to `pm.Binomial` and `pm.Categorical` distributions (see [5637](https://github.com/pymc-devs/pymc/pull/5637)).
136137
- ...
137138

138139

pymc/distributions/discrete.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,25 @@ class Binomial(Discrete):
106106
107107
Parameters
108108
----------
109-
n: int
109+
n : int
110110
Number of Bernoulli trials (n >= 0).
111-
p: float
111+
p : float
112112
Probability of success in each trial (0 < p < 1).
113+
logit_p : float
114+
Alternative log odds for the probability of success.
113115
"""
114116
rv_op = binomial
115117

116118
@classmethod
117-
def dist(cls, n, p, *args, **kwargs):
119+
def dist(cls, n, p=None, logit_p=None, *args, **kwargs):
120+
if p is not None and logit_p is not None:
121+
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
122+
elif p is None and logit_p is None:
123+
raise ValueError("Incompatible parametrization. Must specify either p or logit_p.")
124+
125+
if logit_p is not None:
126+
p = at.sigmoid(logit_p)
127+
118128
n = at.as_tensor_variable(intX(n))
119129
p = at.as_tensor_variable(floatX(p))
120130
return super().dist([n, p], **kwargs)
@@ -1245,14 +1255,24 @@ class Categorical(Discrete):
12451255
12461256
Parameters
12471257
----------
1248-
p: array of floats
1258+
p : array of floats
12491259
p > 0 and the elements of p must sum to 1. They will be automatically
12501260
rescaled otherwise.
1261+
logit_p : float
1262+
Alternative log odds for the probability of success.
12511263
"""
12521264
rv_op = categorical
12531265

12541266
@classmethod
1255-
def dist(cls, p, **kwargs):
1267+
def dist(cls, p=None, logit_p=None, **kwargs):
1268+
if p is not None and logit_p is not None:
1269+
raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
1270+
elif p is None and logit_p is None:
1271+
raise ValueError("Incompatible parametrization. Must specify either p or logit_p.")
1272+
1273+
if logit_p is not None:
1274+
p = pm.math.softmax(logit_p, axis=-1)
1275+
12561276
if isinstance(p, np.ndarray) or isinstance(p, list):
12571277
if (np.asarray(p) < 0).any():
12581278
raise ValueError(f"Negative `p` parameters are not valid, got: {p}")

pymc/tests/test_distributions_random.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def random_polyagamma(*args, **kwargs):
3838
raise RuntimeError("polyagamma package is not installed!")
3939

4040

41-
from scipy.special import expit
41+
from scipy.special import expit, softmax
4242

4343
import pymc as pm
4444

@@ -1006,6 +1006,25 @@ class TestBinomial(BaseTestDistributionRandom):
10061006
checks_to_run = ["check_pymc_params_match_rv_op"]
10071007

10081008

1009+
class TestLogitBinomial(BaseTestDistributionRandom):
1010+
pymc_dist = pm.Binomial
1011+
pymc_dist_params = {"n": 100, "logit_p": 0.5}
1012+
expected_rv_op_params = {"n": 100, "p": expit(0.5)}
1013+
tests_to_run = ["check_pymc_params_match_rv_op"]
1014+
1015+
@pytest.mark.parametrize(
1016+
"n, p, logit_p, expected",
1017+
[
1018+
(5, None, None, "Must specify either p or logit_p."),
1019+
(5, 0.5, 0.5, "Can't specify both p and logit_p."),
1020+
],
1021+
)
1022+
def test_binomial_init_fail(self, n, p, logit_p, expected):
1023+
with pm.Model() as model:
1024+
with pytest.raises(ValueError, match=f"Incompatible parametrization. {expected}"):
1025+
pm.Binomial("x", n=n, p=p, logit_p=logit_p)
1026+
1027+
10091028
class TestNegativeBinomial(BaseTestDistributionRandom):
10101029
pymc_dist = pm.NegativeBinomial
10111030
pymc_dist_params = {"n": 100, "p": 0.33}
@@ -1411,6 +1430,30 @@ class TestCategorical(BaseTestDistributionRandom):
14111430
]
14121431

14131432

1433+
class TestLogitCategorical(BaseTestDistributionRandom):
1434+
pymc_dist = pm.Categorical
1435+
pymc_dist_params = {"logit_p": np.array([[0.28, 0.62, 0.10], [0.28, 0.62, 0.10]])}
1436+
expected_rv_op_params = {
1437+
"p": softmax(np.array([[0.28, 0.62, 0.10], [0.28, 0.62, 0.10]]), axis=-1)
1438+
}
1439+
tests_to_run = [
1440+
"check_pymc_params_match_rv_op",
1441+
"check_rv_size",
1442+
]
1443+
1444+
@pytest.mark.parametrize(
1445+
"p, logit_p, expected",
1446+
[
1447+
(None, None, "Must specify either p or logit_p."),
1448+
(0.5, 0.5, "Can't specify both p and logit_p."),
1449+
],
1450+
)
1451+
def test_categorical_init_fail(self, p, logit_p, expected):
1452+
with pm.Model() as model:
1453+
with pytest.raises(ValueError, match=f"Incompatible parametrization. {expected}"):
1454+
pm.Categorical("x", p=p, logit_p=logit_p)
1455+
1456+
14141457
class TestGeometric(BaseTestDistributionRandom):
14151458
pymc_dist = pm.Geometric
14161459
pymc_dist_params = {"p": 0.9}

0 commit comments

Comments
 (0)