Skip to content

Commit 93c4a60

Browse files
committed
unit tests
1 parent 48e93f3 commit 93c4a60

File tree

3 files changed

+141
-5
lines changed

3 files changed

+141
-5
lines changed

pymc_extras/distributions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
BetaNegativeBinomial,
2323
GeneralizedPoisson,
2424
Skellam,
25+
GrassiaIIGeometric,
2526
)
2627
from pymc_extras.distributions.histogram_utils import histogram_approximation
2728
from pymc_extras.distributions.multivariate import R2D2M2CP

pymc_extras/distributions/discrete.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import numpy as np
1616
import pymc as pm
1717

18+
from pymc.distributions.distribution import Discrete
1819
from pymc.distributions.dist_math import betaln, check_parameters, factln, logpow
1920
from pymc.distributions.shape_utils import rv_size_is_none
2021
from pytensor import tensor as pt
@@ -441,12 +442,11 @@ def sim_data(lam):
441442
g2g = GrassiaIIGeometricRV()
442443

443444

444-
class GrassiaIIGeometric(UnitContinuous):
445+
class GrassiaIIGeometric(Discrete):
445446
r"""Grassia(II)-Geometric distribution.
446447
447-
This distribution is a flexible alternative to the Geometric distribution for the
448-
number of trials until a discrete event, and can be easily extended to support both static
449-
and time-varying covariates.
448+
This distribution is a flexible alternative to the Geometric distribution for the number of trials until a
449+
discrete event, and can be extended to support both static and time-varying covariates.
450450
451451
Hardie and Fader describe this distribution with the following PMF and survival functions in [1]_:
452452
@@ -520,4 +520,22 @@ def logcdf(value, r, alpha):
520520
r > 0,
521521
alpha > 0,
522522
msg="r > 0, alpha > 0",
523-
)
523+
)
524+
525+
def support_point(rv, size, r, alpha):
526+
"""Calculate a reasonable starting point for sampling.
527+
528+
For the GrassiaIIGeometric distribution, we use a point estimate based on
529+
the expected value of the mixing distribution. Since the mixing distribution
530+
is Gamma(r, 1/alpha), its mean is r/alpha. We then transform this through
531+
the geometric link function and round to ensure an integer value.
532+
"""
533+
# E[lambda] = r/alpha for Gamma(r, 1/alpha)
534+
# p = 1 - exp(-lambda) for geometric
535+
# E[T] = 1/p for geometric
536+
mean = pt.ceil(pt.exp(alpha/r)) # exp(alpha/r) upper-bounds 1/p evaluated at E[lambda]=r/alpha, giving a conservative integer starting value
537+
538+
if not rv_size_is_none(size):
539+
mean = pt.full(size, mean)
540+
541+
return mean

tests/distributions/test_discrete.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
BetaNegativeBinomial,
3535
GeneralizedPoisson,
3636
Skellam,
37+
GrassiaIIGeometric,
3738
)
3839

3940

@@ -208,3 +209,119 @@ def test_logp(self):
208209
{"mu1": Rplus_small, "mu2": Rplus_small},
209210
lambda value, mu1, mu2: scipy.stats.skellam.logpmf(value, mu1, mu2),
210211
)
212+
213+
214+
class TestGrassiaIIGeometric:
215+
class TestRandomVariable(BaseTestDistributionRandom):
216+
pymc_dist = GrassiaIIGeometric
217+
pymc_dist_params = {"r": 1.0, "alpha": 2.0}
218+
expected_rv_op_params = {"r": 1.0, "alpha": 2.0}
219+
tests_to_run = [
220+
"check_pymc_params_match_rv_op",
221+
"check_rv_size",
222+
]
223+
224+
def test_random_basic_properties(self):
225+
discrete_random_tester(
226+
dist=self.pymc_dist,
227+
paramdomains={"r": Rplus, "alpha": Rplus},
228+
ref_rand=lambda r, alpha, size: np.random.geometric(
229+
1 - np.exp(-np.random.gamma(r, 1/alpha, size=size)), size=size
230+
),
231+
)
232+
233+
@pytest.mark.parametrize("r,alpha", [
234+
(0.5, 1.0),
235+
(1.0, 2.0),
236+
(2.0, 0.5),
237+
(5.0, 1.0),
238+
])
239+
def test_random_moments(self, r, alpha):
240+
dist = self.pymc_dist.dist(r=r, alpha=alpha, size=10_000)
241+
draws = dist.eval()
242+
243+
# Check that all values are positive integers
244+
assert np.all(draws > 0)
245+
assert np.all(draws.astype(int) == draws)
246+
247+
# Check that values are reasonably distributed
248+
# Note: Exact moments are complex for this distribution
249+
# so we just check basic properties
250+
assert np.mean(draws) > 0
251+
assert np.var(draws) > 0
252+
253+
def test_logp_basic(self):
254+
r = pt.scalar("r")
255+
alpha = pt.scalar("alpha")
256+
value = pt.vector("value", dtype="int64")
257+
258+
logp = pm.logp(GrassiaIIGeometric.dist(r, alpha), value)
259+
logp_fn = pytensor.function([value, r, alpha], logp)
260+
261+
# Test basic properties of logp
262+
test_value = np.array([1, 2, 3, 4, 5])
263+
test_r = 1.0
264+
test_alpha = 1.0
265+
266+
logp_vals = logp_fn(test_value, test_r, test_alpha)
267+
assert not np.any(np.isnan(logp_vals))
268+
assert np.all(np.isfinite(logp_vals))
269+
270+
# Test invalid values
271+
assert logp_fn(np.array([0]), test_r, test_alpha) == -np.inf # Value must be > 0, so logp of 0 is -inf
272+
273+
with pytest.raises(TypeError):
274+
logp_fn(np.array([1.5]), test_r, test_alpha) == -np.inf # Value must be integer
275+
276+
# Test parameter restrictions
277+
with pytest.raises(ParameterValueError):
278+
logp_fn(np.array([1]), -1.0, test_alpha) # r must be > 0
279+
280+
with pytest.raises(ParameterValueError):
281+
logp_fn(np.array([1]), test_r, -1.0) # alpha must be > 0
282+
283+
def test_sampling_consistency(self):
284+
"""Test that sampling from the distribution produces reasonable results"""
285+
r = 2.0
286+
alpha = 1.0
287+
with pm.Model():
288+
x = GrassiaIIGeometric("x", r=r, alpha=alpha)
289+
trace = pm.sample(chains=1, draws=1000, random_seed=42).posterior
290+
291+
samples = trace["x"].values.flatten()
292+
293+
# Check basic properties of samples
294+
assert np.all(samples > 0) # All values should be positive
295+
assert np.all(samples.astype(int) == samples) # All values should be integers
296+
297+
# Check mean and variance are reasonable
298+
# (exact values depend on the parameterization)
299+
assert 0 < np.mean(samples) < np.inf
300+
assert 0 < np.var(samples) < np.inf
301+
302+
@pytest.mark.parametrize(
303+
"r, alpha, size, expected_shape",
304+
[
305+
(1.0, 1.0, None, ()), # Scalar output
306+
([1.0, 2.0], 1.0, None, (2,)), # Vector output from r
307+
(1.0, [1.0, 2.0], None, (2,)), # Vector output from alpha
308+
(1.0, 1.0, (3, 2), (3, 2)), # Explicit size
309+
],
310+
)
311+
def test_support_point(self, r, alpha, size, expected_shape):
312+
"""Test that support_point returns reasonable values with correct shapes"""
313+
with pm.Model() as model:
314+
GrassiaIIGeometric("x", r=r, alpha=alpha, size=size)
315+
316+
init_point = model.initial_point()["x"]
317+
318+
# Check shape
319+
assert init_point.shape == expected_shape
320+
321+
# Check values are positive integers
322+
assert np.all(init_point > 0)
323+
assert np.all(init_point.astype(int) == init_point)
324+
325+
# Check values are finite and reasonable
326+
assert np.all(np.isfinite(init_point))
327+
assert np.all(init_point < 1e6) # Should not be extremely large

0 commit comments

Comments
 (0)