Skip to content

Commit 0d1dcea

Browse files
committed
WIP time covars required param
1 parent ab45a9c commit 0d1dcea

File tree

2 files changed

+21
-51
lines changed

2 files changed

+21
-51
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -412,9 +412,7 @@ class GrassiaIIGeometricRV(RandomVariable):
412412
@classmethod
413413
def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
414414
# Aggregate time covariates for each sample before broadcasting
415-
exp_time_covar = np.exp(
416-
time_covariate_vector.sum(axis=0)
417-
) # TODO: try np.exp(time_covariate_vector).sum(axis=0) instead?
415+
exp_time_covar = np.exp(time_covariate_vector).sum(axis=0)
418416

419417
# Determine output size
420418
if size is None:
@@ -427,7 +425,7 @@ def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
427425

428426
lam = rng.gamma(shape=r, scale=1 / alpha, size=size)
429427

430-
lam_covar = lam * exp_time_covar # TODO: test summing over this in a notebook as well?
428+
lam_covar = lam * exp_time_covar
431429

432430
p = 1 - np.exp(-lam_covar)
433431
samples = rng.geometric(p)
@@ -483,8 +481,8 @@ class GrassiaIIGeometric(Discrete):
483481
Shape parameter (r > 0).
484482
alpha : tensor_like of float
485483
Scale parameter (alpha > 0).
486-
time_covariate_vector : tensor_like of float, optional
487-
Optional vector containing dot products of time-varying covariates and coefficients.
484+
time_covariate_vector : tensor_like of float
485+
Vector containing dot products of time-varying covariates and coefficients.
488486
489487
References
490488
----------
@@ -496,11 +494,9 @@ class GrassiaIIGeometric(Discrete):
496494
rv_op = g2g
497495

498496
@classmethod
499-
def dist(cls, r, alpha, time_covariate_vector=None, *args, **kwargs):
497+
def dist(cls, r, alpha, time_covariate_vector, *args, **kwargs):
500498
r = pt.as_tensor_variable(r)
501499
alpha = pt.as_tensor_variable(alpha)
502-
if time_covariate_vector is None:
503-
time_covariate_vector = pt.constant(0.0)
504500
time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)
505501
return super().dist([r, alpha, time_covariate_vector], *args, **kwargs)
506502

@@ -537,7 +533,6 @@ def logcdf(value, r, alpha, time_covariate_vector):
537533
logcdf,
538534
r > 0,
539535
alpha > 0,
540-
time_covariate_vector >= 0,
541536
msg="r > 0, alpha > 0",
542537
)
543538

@@ -575,16 +570,12 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
575570
return mean
576571

577572

578-
# TODO: can this be moved into logp? Indexing not required for logcdf
579573
def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.TensorVariable:
580574
"""Utility for processing time-varying covariates in GrassiaIIGeometric distribution."""
581-
if time_covariate_vector.ndim == 0:
582-
return t
583-
else:
584-
# Ensure t is a valid index
585-
t_idx = pt.maximum(0, t - 1) # Convert to 0-based indexing
586-
# If t_idx exceeds length of time_covariate_vector, use last value
587-
max_idx = pt.shape(time_covariate_vector)[0] - 1
588-
safe_idx = pt.minimum(t_idx, max_idx)
589-
covariate_value = time_covariate_vector[..., safe_idx]
590-
return pt.exp(covariate_value).sum()
575+
# Ensure t is a valid index
576+
t_idx = pt.maximum(0, t - 1) # Convert to 0-based indexing
577+
# If t_idx exceeds length of time_covariate_vector, use last value
578+
max_idx = pt.shape(time_covariate_vector)[0] - 1
579+
safe_idx = pt.minimum(t_idx, max_idx)
580+
covariate_value = time_covariate_vector[..., safe_idx]
581+
return pt.exp(covariate_value).sum(axis=0)

tests/distributions/test_discrete.py

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ def test_logp(self):
217217
class TestGrassiaIIGeometric:
218218
class TestRandomVariable(BaseTestDistributionRandom):
219219
pymc_dist = GrassiaIIGeometric
220-
pymc_dist_params = {"r": 0.5, "alpha": 2.0, "time_covariate_vector": [0.0]}
221-
expected_rv_op_params = {"r": 0.5, "alpha": 2.0, "time_covariate_vector": [0.0]}
220+
pymc_dist_params = {"r": 0.5, "alpha": 2.0, "time_covariate_vector": [1.0, 2.0, 3.0]}
221+
expected_rv_op_params = {"r": 0.5, "alpha": 2.0, "time_covariate_vector": [1.0, 2.0, 3.0]}
222222
tests_to_run = [
223223
"check_pymc_params_match_rv_op",
224224
"check_rv_size",
@@ -250,7 +250,7 @@ def test_random_edge_cases(self):
250250
# Test with small r and large alpha values
251251
r_vals = [0.1, 0.5]
252252
alpha_vals = [5.0, 10.0]
253-
time_cov_vals = [[0.0], [1.0]]
253+
time_cov_vals = [[0.0, 1.0, 2.0], [5.0, 10.0, 15.0]]
254254

255255
for r in r_vals:
256256
for alpha in alpha_vals:
@@ -266,34 +266,13 @@ def test_random_edge_cases(self):
266266
assert np.mean(draws) > 0
267267
assert np.var(draws) > 0
268268

269-
def test_random_none_covariates(self):
270-
"""Test random sampling with None time_covariate_vector"""
271-
r_vals = [0.5, 1.0, 2.0]
272-
alpha_vals = [0.5, 1.0, 2.0]
273-
274-
for r in r_vals:
275-
for alpha in alpha_vals:
276-
dist = self.pymc_dist.dist(
277-
r=r,
278-
alpha=alpha,
279-
time_covariate_vector=[0.0], # Changed from None to avoid zip issues
280-
size=1000,
281-
)
282-
draws = dist.eval()
283-
284-
# Check basic properties
285-
assert np.all(draws > 0)
286-
assert np.all(draws.astype(int) == draws)
287-
assert np.mean(draws) > 0
288-
assert np.var(draws) > 0
289-
290269
@pytest.mark.parametrize(
291270
"r,alpha,time_covariate_vector",
292271
[
293-
(0.5, 1.0, None),
272+
(0.5, 1.0, [[0.0], [0.0], [0.0]]),
294273
(1.0, 2.0, [1.0]),
295274
(2.0, 0.5, [[1.0], [2.0]]),
296-
([5.0], [1.0], None),
275+
([5.0], [1.0], [0.0, 0.0, 0.0]),
297276
],
298277
)
299278
def test_random_moments(self, r, alpha, time_covariate_vector):
@@ -311,7 +290,7 @@ def test_logp(self):
311290
# Create PyTensor variables with explicit values to ensure proper initialization
312291
r = pt.as_tensor_variable(1.0)
313292
alpha = pt.as_tensor_variable(2.0)
314-
time_covariate_vector = pt.as_tensor_variable([0.5, 1.0])
293+
time_covariate_vector = pt.as_tensor_variable([[0.5, 1.0, 1.5], [0.0, 0.0, 0.0]])
315294
value = pt.vector("value", dtype="int64")
316295

317296
# Create the distribution with the PyTensor variables
@@ -335,16 +314,16 @@ def test_logp(self):
335314
def test_logcdf(self):
336315
# test logcdf matches log sums across parameter values
337316
check_selfconsistency_discrete_logcdf(
338-
GrassiaIIGeometric, NatBig, {"r": Rplus, "alpha": Rplus, "time_covariate_vector": Rplus}
317+
GrassiaIIGeometric, NatBig, {"r": Rplus, "alpha": Rplus, "time_covariate_vector": I}
339318
)
340319

341320
@pytest.mark.parametrize(
342321
"r, alpha, time_covariate_vector, size, expected_shape",
343322
[
344-
(1.0, 1.0, None, None, ()), # Scalar output with no covariates
323+
(1.0, 1.0, [0.0, 0.0, 0.0], None, ()), # Scalar output
345324
([1.0, 2.0], 1.0, [0.0], None, (2,)), # Vector output from r
346325
(1.0, [1.0, 2.0], [0.0], None, (2,)), # Vector output from alpha
347-
(1.0, 1.0, [1.0, 2.0], None, ()), # Vector output from time covariates
326+
(1.0, 1.0, [[1.0, 2.0], [3.0, 4.0]], None, (2,)), # Vector output from time covariates
348327
(1.0, 1.0, [1.0, 2.0], (3, 2), (3, 2)), # Explicit size with time covariates
349328
],
350329
)

0 commit comments

Comments
 (0)