Skip to content

Commit d5a98c6

Browse files
committed
WIP test log_p
1 parent 11a8cac commit d5a98c6

File tree

2 files changed

+20
-48
lines changed

2 files changed

+20
-48
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,6 @@ def rng_fn(cls, rng, alpha, beta, size):
429429
sbg = ShiftedBetaGeometricRV()
430430

431431

432-
# TODO: Update docstrings for sBG, including plotting code
433432
class ShiftedBetaGeometric(Discrete):
434433
r"""Shifted Beta-Geometric distribution.
435434
@@ -517,8 +516,11 @@ def logp(value, alpha, beta):
517516

518517
logp = pt.switch(
519518
pt.or_(
520-
alpha <= 0,
521-
beta <= 0,
519+
pt.or_(
520+
alpha <= 0,
521+
beta <= 0,
522+
),
523+
pt.lt(value, 1),
522524
),
523525
-np.inf,
524526
logp,
@@ -535,13 +537,13 @@ def support_point(rv, size, alpha, beta):
535537
"""Calculate a reasonable starting point for sampling.
536538
537539
For the Shifted Beta-Geometric distribution, we use a point estimate based on
538-
the expected value of both mixture components.
540+
the expected value of the mixture components.
539541
540542
"""
541543
geo_mean = pt.ceil(
542-
pt.reciprocal( # expected value of the geometric distribution
544+
pt.reciprocal(
543545
alpha / (alpha + beta) # expected value of the beta distribution
544-
)
546+
) # expected value of the geometric distribution
545547
)
546548
if not rv_size_is_none(size):
547549
geo_mean = pt.full(size, geo_mean)

tests/distributions/test_discrete.py

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -215,37 +215,17 @@ def test_logp(self):
215215
class TestShiftedBetaGeometric:
216216
class TestRandomVariable(BaseTestDistributionRandom):
217217
pymc_dist = ShiftedBetaGeometric
218-
pymc_dist_params = {"alpha": 2.0, "beta": 3.0}
219-
expected_rv_op_params = {"alpha": 2.0, "beta": 3.0}
218+
pymc_dist_params = {"alpha": 1.0, "beta": 1.0}
219+
expected_rv_op_params = {"alpha": 1.0, "beta": 1.0}
220220
tests_to_run = [
221221
"check_pymc_params_match_rv_op",
222222
"check_rv_size",
223223
]
224224

225-
# TODO: Adapt this to ShiftedBetaGeometric and delete random_moments tests?
226-
# def test_random_matches_geometric(self):
227-
# discrete_random_tester(
228-
# dist=self.pymc_dist,
229-
# paramdomains={"theta": Rplus, "alpha": Domain([0], edges=(None, None))},
230-
# ref_rand=lambda mu, lam, size: scipy.stats.geometric.rvs(theta, size=size),
231-
# )
232-
233-
# @pytest.mark.parametrize("mu", (2.5, 20, 50))
234-
# def test_random_lam_expected_moments(self, mu):
235-
# lam = np.array([-0.9, -0.7, -0.2, 0, 0.2, 0.7, 0.9])
236-
# dist = self.pymc_dist.dist(mu=mu, lam=lam, size=(10_000, len(lam)))
237-
# draws = dist.eval()
238-
239-
# expected_mean = mu / (1 - lam)
240-
# np.testing.assert_allclose(draws.mean(0), expected_mean, rtol=1e-1)
241-
242-
# expected_std = np.sqrt(mu / (1 - lam) ** 3)
243-
# np.testing.assert_allclose(draws.std(0), expected_std, rtol=1e-1)
244-
245225
def test_random_basic_properties(self):
246226
"""Test basic random sampling properties"""
247227
# Test with standard parameter values
248-
alpha_vals = [0.5, 1.0, 2.0]
228+
alpha_vals = [1.0, 0.5, 2.0]
249229
beta_vals = [0.5, 1.0, 2.0]
250230

251231
for alpha in alpha_vals:
@@ -277,16 +257,14 @@ def test_random_edge_cases(self):
277257
assert np.var(draws) > 0
278258

279259
@pytest.mark.parametrize(
280-
"alpha,beta",
260+
"alpha",
281261
[
282-
(0.5, 1.0),
283-
(1.0, np.array([2.0, 1.0])),
284-
(np.array([1.0, 2.0]), 1.0),
285-
(np.array([2.0, 0.5]), np.array([1.0, 2.0])),
262+
(0.5, 1.0, 10.0),
286263
],
287264
)
288-
def test_random_moments(self, alpha, beta):
289-
dist = self.pymc_dist.dist(alpha=alpha, beta=beta, size=10_000)
265+
def test_random_moments(self, alpha):
266+
beta = np.array([0.5, 1.0, 10.0])
267+
dist = self.pymc_dist.dist(alpha=alpha, beta=beta, size=(10_000, len(beta)))
290268
draws = dist.eval()
291269

292270
assert np.all(draws > 0)
@@ -300,7 +278,7 @@ def test_logp(self):
300278
beta = pt.scalar("beta")
301279
value = pt.vector("value", dtype="int64")
302280

303-
# Check out-of-bounds values
281+
# Compile logp function for testing
304282
dist = ShiftedBetaGeometric.dist(alpha, beta)
305283
logp = pm.logp(dist, value)
306284
logp_fn = pytensor.function([value, alpha, beta], logp)
@@ -311,21 +289,13 @@ def test_logp(self):
311289
assert not np.any(np.isnan(logp_vals))
312290
assert np.all(np.isfinite(logp_vals))
313291

314-
# Check out-of-bounds values
315-
value = pt.scalar("value")
316-
logp = pm.logp(ShiftedBetaGeometric.dist(alpha, beta), value)
317-
logp_fn = pytensor.function([value, alpha, beta], logp)
318-
319-
logp_fn(-1, alpha=5, beta=0) == -np.inf
320-
logp_fn(9, alpha=5, beta=-1) == -np.inf
292+
assert logp_fn(-1, alpha=5, beta=1) == -np.inf
321293

322-
# Check mu/lam restrictions
294+
# Check alpha/beta restrictions
323295
with pytest.raises(ParameterValueError):
324-
logp_fn(1, alpha=1, beta=2)
325-
296+
logp_fn(1, alpha=-1, beta=2)
326297
with pytest.raises(ParameterValueError):
327298
logp_fn(1, alpha=0, beta=0)
328-
329299
with pytest.raises(ParameterValueError):
330300
logp_fn(1, alpha=1, beta=-1)
331301

0 commit comments

Comments
 (0)