Skip to content

Commit 416737e

Browse files
committed
Fix logcdf of DiscreteUniform at lower bound
logcdf(x=lower) != -inf
1 parent f4c82c1 commit 416737e

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

pymc/distributions/discrete.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,11 +1052,11 @@ def logp(value, lower, upper):
10521052

10531053
def logcdf(value, lower, upper):
10541054
res = pt.switch(
1055-
pt.le(value, lower),
1055+
pt.lt(value, lower),
10561056
-np.inf,
10571057
pt.switch(
10581058
pt.lt(value, upper),
1059-
pt.log(pt.minimum(pt.floor(value), upper) - lower + 1) - pt.log(upper - lower + 1),
1059+
pt.log(pt.floor(value) - lower + 1) - pt.log(upper - lower + 1),
10601060
0,
10611061
),
10621062
)

tests/distributions/test_discrete.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@
4141
Nat,
4242
NatSmall,
4343
R,
44-
Rdunif,
4544
Rplus,
46-
Rplusdunif,
4745
Runif,
4846
Simplex,
4947
Unit,
@@ -95,32 +93,36 @@ def orderedprobit_logpdf(value, eta, cutpoints):
9593

9694

9795
class TestMatchesScipy:
98-
def test_discrete_unif(self):
96+
def test_discrete_uniform(self):
97+
# Choose domain/paramdomain so we test edge cases as well
98+
test_domain = Domain([-np.inf, -10, -1, 0, 1, 10, np.inf], dtype="int64")
99+
test_paramdomain = Domain([-np.inf, 0, 10, np.inf], dtype="int64")
99100
check_logp(
100101
pm.DiscreteUniform,
101-
Rdunif,
102-
{"lower": -Rplusdunif, "upper": Rplusdunif},
102+
test_domain,
103+
{"lower": -test_paramdomain, "upper": test_paramdomain},
103104
lambda value, lower, upper: st.randint.logpmf(value, lower, upper + 1),
104105
skip_paramdomain_outside_edge_test=True,
105106
)
106107
check_logcdf(
107108
pm.DiscreteUniform,
108-
Rdunif,
109-
{"lower": -Rplusdunif, "upper": Rplusdunif},
109+
test_domain,
110+
{"lower": -test_paramdomain, "upper": test_paramdomain},
110111
lambda value, lower, upper: st.randint.logcdf(value, lower, upper + 1),
111112
skip_paramdomain_outside_edge_test=True,
112113
)
113114
check_selfconsistency_discrete_logcdf(
114115
pm.DiscreteUniform,
115116
Domain([-10, 0, 10], "int64"),
116-
{"lower": -Rplusdunif, "upper": Rplusdunif},
117+
{"lower": -test_paramdomain, "upper": test_paramdomain},
117118
)
118119
check_icdf(
119120
pm.DiscreteUniform,
120-
{"lower": -Rplusdunif, "upper": Rplusdunif},
121+
{"lower": -test_paramdomain, "upper": test_paramdomain},
121122
lambda q, lower, upper: st.randint.ppf(q=q, low=lower, high=upper + 1),
122123
skip_paramdomain_outside_edge_test=True,
123124
)
125+
124126
# Custom logp / logcdf check for invalid parameters
125127
invalid_dist = pm.DiscreteUniform.dist(lower=1, upper=0)
126128
with pytensor.config.change_flags(mode=Mode("py")):

0 commit comments

Comments
 (0)