Skip to content

Commit ab45a9c

Browse files
committed
WIP rng_fn testing
1 parent b78a5c4 commit ab45a9c

File tree

2 files changed

+29
-25
lines changed

2 files changed

+29
-25
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -411,24 +411,27 @@ class GrassiaIIGeometricRV(RandomVariable):
411411

412412
@classmethod
413413
def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
414+
# Aggregate time covariates for each sample before broadcasting
415+
exp_time_covar = np.exp(
416+
time_covariate_vector.sum(axis=0)
417+
) # TODO: try np.exp(time_covariate_vector).sum(axis=0) instead?
418+
414419
# Determine output size
415420
if size is None:
416-
size = np.broadcast_shapes(r.shape, alpha.shape, time_covariate_vector.shape)
421+
size = np.broadcast_shapes(r.shape, alpha.shape, exp_time_covar.shape)
417422

418423
# Broadcast parameters to output size
419424
r = np.broadcast_to(r, size)
420425
alpha = np.broadcast_to(alpha, size)
421-
time_covariate_vector = np.broadcast_to(time_covariate_vector, size)
426+
exp_time_covar = np.broadcast_to(exp_time_covar, size)
422427

423428
lam = rng.gamma(shape=r, scale=1 / alpha, size=size)
424429

425-
# Aggregate time covariates for each sample
426-
exp_time_covar = np.exp(
427-
time_covariate_vector.sum(axis=0)
428-
) # TODO: try np.exp(time_covariate_vector).sum(axis=0) instead?
429-
lam_covar = lam * exp_time_covar
430+
lam_covar = lam * exp_time_covar # TODO: test summing over this in a notebook as well?
430431

431-
samples = np.ceil(np.log(1 - rng.uniform(size=size)) / (-lam_covar))
432+
p = 1 - np.exp(-lam_covar)
433+
samples = rng.geometric(p)
434+
# samples = np.ceil(np.log(1 - rng.uniform(size=size)) / (-lam_covar))
432435

433436
return samples
434437

@@ -560,7 +563,7 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
560563
)
561564

562565
# Apply time covariates if provided
563-
mean = mean * pt.exp(time_covariate_vector)
566+
mean = mean * pt.exp(time_covariate_vector.sum(axis=0))
564567

565568
# Round up to nearest integer and ensure >= 1
566569
mean = pt.maximum(pt.ceil(mean), 1.0)

tests/distributions/test_discrete.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
BaseTestDistributionRandom,
2525
Domain,
2626
I,
27+
NatBig,
2728
Rplus,
2829
assert_support_point_is_expected,
2930
check_logp,
@@ -216,8 +217,8 @@ def test_logp(self):
216217
class TestGrassiaIIGeometric:
217218
class TestRandomVariable(BaseTestDistributionRandom):
218219
pymc_dist = GrassiaIIGeometric
219-
pymc_dist_params = {"r": 0.5, "alpha": 2.0, "time_covariate_vector": 0.0}
220-
expected_rv_op_params = {"r": 0.5, "alpha": 2.0, "time_covariate_vector": 0.0}
220+
pymc_dist_params = {"r": 0.5, "alpha": 2.0, "time_covariate_vector": [0.0]}
221+
expected_rv_op_params = {"r": 0.5, "alpha": 2.0, "time_covariate_vector": [0.0]}
221222
tests_to_run = [
222223
"check_pymc_params_match_rv_op",
223224
"check_rv_size",
@@ -228,7 +229,7 @@ def test_random_basic_properties(self):
228229
# Test with standard parameter values
229230
r_vals = [0.5, 1.0, 2.0]
230231
alpha_vals = [0.5, 1.0, 2.0]
231-
time_cov_vals = [-1.0, 1.0, 2.0]
232+
time_cov_vals = [[0.0], [1.0], [2.0]]
232233

233234
for r in r_vals:
234235
for alpha in alpha_vals:
@@ -275,7 +276,7 @@ def test_random_none_covariates(self):
275276
dist = self.pymc_dist.dist(
276277
r=r,
277278
alpha=alpha,
278-
time_covariate_vector=0.0, # Changed from None to avoid zip issues
279+
time_covariate_vector=[0.0], # Changed from None to avoid zip issues
279280
size=1000,
280281
)
281282
draws = dist.eval()
@@ -289,10 +290,10 @@ def test_random_none_covariates(self):
289290
@pytest.mark.parametrize(
290291
"r,alpha,time_covariate_vector",
291292
[
292-
(0.5, 1.0, 0.0),
293-
(1.0, 2.0, 1.0),
294-
(2.0, 0.5, -1.0),
295-
(5.0, 1.0, 0.0), # Changed from None to avoid zip issues
293+
(0.5, 1.0, None),
294+
(1.0, 2.0, [1.0]),
295+
(2.0, 0.5, [[1.0], [2.0]]),
296+
([5.0], [1.0], None),
296297
],
297298
)
298299
def test_random_moments(self, r, alpha, time_covariate_vector):
@@ -306,11 +307,11 @@ def test_random_moments(self, r, alpha, time_covariate_vector):
306307
assert np.mean(draws) > 0
307308
assert np.var(draws) > 0
308309

309-
def test_logp_basic(self):
310+
def test_logp(self):
310311
# Create PyTensor variables with explicit values to ensure proper initialization
311312
r = pt.as_tensor_variable(1.0)
312313
alpha = pt.as_tensor_variable(2.0)
313-
time_covariate_vector = pt.as_tensor_variable(0.5)
314+
time_covariate_vector = pt.as_tensor_variable([0.5, 1.0])
314315
value = pt.vector("value", dtype="int64")
315316

316317
# Create the distribution with the PyTensor variables
@@ -334,17 +335,17 @@ def test_logp_basic(self):
334335
def test_logcdf(self):
335336
# test logcdf matches log sums across parameter values
336337
check_selfconsistency_discrete_logcdf(
337-
GrassiaIIGeometric, I, {"r": Rplus, "alpha": Rplus, "time_covariate_vector": I}
338+
GrassiaIIGeometric, NatBig, {"r": Rplus, "alpha": Rplus, "time_covariate_vector": Rplus}
338339
)
339340

340341
@pytest.mark.parametrize(
341342
"r, alpha, time_covariate_vector, size, expected_shape",
342343
[
343-
(1.0, 1.0, 0.0, None, ()), # Scalar output with no covariates (0.0 instead of None)
344-
([1.0, 2.0], 1.0, 0.0, None, (2,)), # Vector output from r
345-
(1.0, [1.0, 2.0], 0.0, None, (2,)), # Vector output from alpha
346-
(1.0, 1.0, [1.0, 2.0], None, (2,)), # Vector output from time covariates
347-
(1.0, 1.0, 1.0, (3, 2), (3, 2)), # Explicit size with scalar time covariates
344+
(1.0, 1.0, None, None, ()), # Scalar output with no covariates
345+
([1.0, 2.0], 1.0, [0.0], None, (2,)), # Vector output from r
346+
(1.0, [1.0, 2.0], [0.0], None, (2,)), # Vector output from alpha
347+
(1.0, 1.0, [1.0, 2.0], None, ()), # Vector output from time covariates
348+
(1.0, 1.0, [1.0, 2.0], (3, 2), (3, 2)), # Explicit size with time covariates
348349
],
349350
)
350351
def test_support_point(self, r, alpha, time_covariate_vector, size, expected_shape):

0 commit comments

Comments
 (0)