Skip to content

Commit 3d44733

Browse files
committed
add logcdf
1 parent e3743a3 commit 3d44733

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,8 @@ def dist(cls, alpha, beta, *args, **kwargs):
489489
return super().dist([alpha, beta], *args, **kwargs)
490490

491491
def logp(value, alpha, beta):
492+
"""From Expression (5) on p.6 of Fader & Hardie (2007)"""
493+
492494
logp = betaln(alpha + 1, beta + value - 1) - betaln(alpha, beta)
493495

494496
logp = pt.switch(
@@ -510,6 +512,18 @@ def logp(value, alpha, beta):
510512
msg="alpha > 0, beta > 0",
511513
)
512514

515+
def logcdf(value, alpha, beta):
516+
"""Adapted from Expression (6) on p.6 of Fader & Hardie (2007)"""
517+
# survival function from paper
518+
logS = (
519+
pt.gammaln(beta + value)
520+
- pt.gammaln(beta)
521+
+ pt.gammaln(alpha + beta)
522+
- pt.gammaln(alpha + beta + value)
523+
)
524+
# log(1-exp())
525+
return pt.log1mexp(logS)
526+
513527
def support_point(rv, size, alpha, beta):
514528
"""Calculate a reasonable starting point for sampling.
515529

tests/distributions/test_discrete.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@
2424
BaseTestDistributionRandom,
2525
Domain,
2626
I,
27+
Nat,
2728
Rplus,
2829
assert_support_point_is_expected,
2930
check_logp,
31+
check_selfconsistency_discrete_logcdf,
3032
discrete_random_tester,
3133
)
3234
from pytensor import config
@@ -241,8 +243,8 @@ def test_random_basic_properties(self):
241243

242244
def test_random_edge_cases(self):
243245
"""Test with very small and large beta and alpha values"""
244-
beta_vals = [20.0, 0.07, 18.0, 0.05]
245-
alpha_vals = [20.0, 14.0, 0.06, 0.05]
246+
beta_vals = [20.0, 0.08, 18.0, 0.06]
247+
alpha_vals = [20.0, 14.0, 0.07, 0.06]
246248

247249
for beta in beta_vals:
248250
for alpha in alpha_vals:
@@ -299,6 +301,13 @@ def test_logp_matches_paper(self):
299301
logp_fn = pytensor.function([value, alpha_, beta_], logp)
300302
np.testing.assert_allclose(logp_fn(t_vec, alpha, beta), expected, rtol=1e-2)
301303

304+
def test_logcdf(self):
305+
check_selfconsistency_discrete_logcdf(
306+
distribution=ShiftedBetaGeometric,
307+
domain=Nat,
308+
paramdomains={"alpha": Rplus, "beta": Rplus},
309+
)
310+
302311
@pytest.mark.parametrize(
303312
"alpha, beta, size, expected_shape",
304313
[

0 commit comments

Comments
 (0)