Skip to content

Commit 026f182

Browse files
committed
small lam value tests
1 parent 8685005 commit 026f182

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -426,11 +426,17 @@ def rng_fn(cls, rng, r, alpha, size):
426426
lam = rng.gamma(shape=r, scale=1/alpha, size=size)
427427

428428
def sim_data(lam):
429-
# TODO: To support time-varying covariates, covariate vector may need to be added
430-
p = 1 - np.exp(-lam)
431-
429+
# Handle numerical stability for very small lambda values
430+
p = np.where(
431+
lam < 0.001,
432+
lam, # For small lambda, p ≈ lambda
433+
1 - np.exp(-lam) # Standard formula for larger lambda
434+
)
435+
436+
# Ensure p is in valid range for geometric distribution
437+
p = np.clip(p, 1e-5, 1.)
438+
432439
t = rng.geometric(p)
433-
434440
return np.array([t])
435441

436442
for index in np.ndindex(*size):

tests/distributions/test_discrete.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,13 +222,29 @@ class TestRandomVariable(BaseTestDistributionRandom):
222222
]
223223

224224
def test_random_basic_properties(self):
225+
# Test standard parameter values
225226
discrete_random_tester(
226227
dist=self.pymc_dist,
227-
paramdomains={"r": Rplus, "alpha": Rplus},
228+
paramdomains={
229+
"r": Domain([0.5, 1.0, 2.0], edges=(None, None)), # Standard values
230+
"alpha": Domain([0.5, 1.0, 2.0], edges=(None, None)), # Standard values
231+
},
228232
ref_rand=lambda r, alpha, size: np.random.geometric(
229233
1 - np.exp(-np.random.gamma(r, 1/alpha, size=size)), size=size
230234
),
231235
)
236+
237+
# Test small parameter values that could generate small lambda values
238+
discrete_random_tester(
239+
dist=self.pymc_dist,
240+
paramdomains={
241+
"r": Domain([0.01, 0.1], edges=(None, None)), # Small r values
242+
"alpha": Domain([10.0, 100.0], edges=(None, None)), # Large alpha values
243+
},
244+
ref_rand=lambda r, alpha, size: np.random.geometric(
245+
np.clip(np.random.gamma(r, 1/alpha, size=size), 1e-5, 1.0), size=size
246+
),
247+
)
232248

233249
@pytest.mark.parametrize("r,alpha", [
234250
(0.5, 1.0),

0 commit comments

Comments
 (0)