Skip to content

Commit d2e72b5

Browse files
committed
alpha min value
1 parent 93c4a60 commit d2e72b5

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -512,14 +512,13 @@ def logp(value, r, alpha):
512512
)
513513

514514
def logcdf(value, r, alpha):
515-
# TODO: Math may not be correct here
516515
logcdf = r * (pt.log(value) - pt.log(alpha + value))
517516

518517
return check_parameters(
519518
logcdf,
520519
r > 0,
521-
alpha > 0,
522-
msg="s > 0, alpha > 0",
520+
alpha > 0.6181, # alpha must be greater than 0.6181 for convergence
521+
msg="r > 0, alpha > 0",
523522
)
524523

525524
def support_point(rv, size, r, alpha):
@@ -530,10 +529,7 @@ def support_point(rv, size, r, alpha):
530529
is Gamma(r, 1/alpha), its mean is r/alpha. We then transform this through
531530
the geometric link function and round to ensure an integer value.
532531
"""
533-
# E[lambda] = r/alpha for Gamma(r, 1/alpha)
534-
# p = 1 - exp(-lambda) for geometric
535-
# E[T] = 1/p for geometric
536-
mean = pt.ceil(pt.exp(alpha/r)) # Conservative upper bound
532+
mean = pt.ceil(pt.exp(alpha/r))
537533

538534
if not rv_size_is_none(size):
539535
mean = pt.full(size, mean)

tests/distributions/test_discrete.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,8 @@ def test_logp(self):
214214
class TestGrassiaIIGeometric:
215215
class TestRandomVariable(BaseTestDistributionRandom):
216216
pymc_dist = GrassiaIIGeometric
217-
pymc_dist_params = {"r": 1.0, "alpha": 2.0}
218-
expected_rv_op_params = {"r": 1.0, "alpha": 2.0}
217+
pymc_dist_params = {"r": .5, "alpha": 2.0}
218+
expected_rv_op_params = {"r": .5, "alpha": 2.0}
219219
tests_to_run = [
220220
"check_pymc_params_match_rv_op",
221221
"check_rv_size",
@@ -289,11 +289,11 @@ def test_sampling_consistency(self):
289289
trace = pm.sample(chains=1, draws=1000, random_seed=42).posterior
290290

291291
samples = trace["x"].values.flatten()
292-
292+
293293
# Check basic properties of samples
294294
assert np.all(samples > 0) # All values should be positive
295295
assert np.all(samples.astype(int) == samples) # All values should be integers
296-
296+
297297
# Check mean and variance are reasonable
298298
# (exact values depend on the parameterization)
299299
assert 0 < np.mean(samples) < np.inf

0 commit comments

Comments
 (0)