Skip to content

Commit a715ec7

Browse files
committed
clean up comments and final TODO
1 parent 5baa6f7 commit a715ec7

File tree

2 files changed

+26
-32
lines changed

2 files changed

+26
-32
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ def dist(cls, mu1, mu2, **kwargs):
409409
class GrassiaIIGeometricRV(RandomVariable):
410410
name = "g2g"
411411
signature = "(),(),()->()"
412-
ndims_params = [0, 0, 0] # r, alpha, time_covariate_vector are all scalars
412+
ndims_params = [0, 0, 0] # deprecated in PyTensor 2.31.7, but still required for RandomVariable
413413

414414
dtype = "int64"
415415
_print_name = ("GrassiaIIGeometric", "\\operatorname{GrassiaIIGeometric}")
@@ -430,14 +430,13 @@ def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
430430
alpha = np.broadcast_to(alpha, size)
431431
time_covariate_vector = np.broadcast_to(time_covariate_vector, size)
432432

433-
# Calculate exp(time_covariate_vector) for all samples
434-
exp_time_covar = np.exp(time_covariate_vector)
435-
436-
# Generate gamma samples and apply time covariates
437433
lam = rng.gamma(shape=r, scale=1 / alpha, size=size)
438434

439-
# TODO: Add C(t) to the calculation of lam_covar
435+
# Calculate exp(time_covariate_vector) for all samples
436+
exp_time_covar = np.exp(time_covariate_vector)
440437
lam_covar = lam * exp_time_covar
438+
439+
# TODO: This is not aggregated over time
441440
p = 1 - np.exp(-lam_covar)
442441

443442
# Ensure p is in valid range for geometric distribution
@@ -526,16 +525,18 @@ def logp(value, r, alpha, time_covariate_vector=None):
526525
time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)
527526

528527
def C_t(t):
529-
# Aggregate time_covariate_vector over active time periods
530528
if t == 0:
531529
return pt.constant(0.0)
532-
# Handle case where time_covariate_vector is a scalar
533530
if time_covariate_vector.ndim == 0:
534-
return t * pt.exp(time_covariate_vector)
531+
return t
535532
else:
536-
# For time covariates, this approximation avoids symbolic indexing issues
537-
mean_covariate = pt.mean(time_covariate_vector)
538-
return t * pt.exp(mean_covariate)
533+
# Ensure t is a valid index
534+
t_idx = pt.maximum(0, t - 1) # Convert to 0-based indexing
535+
# If t_idx exceeds length of time_covariate_vector, use last value
536+
max_idx = pt.shape(time_covariate_vector)[0] - 1
537+
safe_idx = pt.minimum(t_idx, max_idx)
538+
covariate_value = time_covariate_vector[safe_idx]
539+
return t * pt.exp(covariate_value)
539540

540541
logp = pt.log(
541542
pt.pow(alpha / (alpha + C_t(value - 1)), r) - pt.pow(alpha / (alpha + C_t(value)), r)
@@ -567,11 +568,15 @@ def C_t(t):
567568
if t == 0:
568569
return pt.constant(0.0)
569570
if time_covariate_vector.ndim == 0:
570-
return t * pt.exp(time_covariate_vector)
571+
return t
571572
else:
572-
# For time covariates, this approximation avoids symbolic indexing issues
573-
mean_covariate = pt.mean(time_covariate_vector)
574-
return t * pt.exp(mean_covariate)
573+
# Ensure t is a valid index
574+
t_idx = pt.maximum(0, t - 1) # Convert to 0-based indexing
575+
# If t_idx exceeds length of time_covariate_vector, use last value
576+
max_idx = pt.shape(time_covariate_vector)[0] - 1
577+
safe_idx = pt.minimum(t_idx, max_idx)
578+
covariate_value = time_covariate_vector[safe_idx]
579+
return t * pt.exp(covariate_value)
575580

576581
survival = pt.pow(alpha / (alpha + C_t(value)), r)
577582
logcdf = pt.log(1 - survival)

tests/distributions/test_discrete.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,8 @@ def test_random_none_covariates(self):
272272
dist = self.pymc_dist.dist(
273273
r=r,
274274
alpha=alpha,
275-
time_covariate_vector=0.0,
276-
size=1000, # Changed from None to 0.0
275+
time_covariate_vector=0.0, # Changed from None to avoid zip issues
276+
size=1000,
277277
)
278278
draws = dist.eval()
279279

@@ -289,7 +289,7 @@ def test_random_none_covariates(self):
289289
(0.5, 1.0, 0.0),
290290
(1.0, 2.0, 1.0),
291291
(2.0, 0.5, -1.0),
292-
(5.0, 1.0, 0.0), # Changed from None to 0.0 to avoid zip issues
292+
(5.0, 1.0, 0.0), # Changed from None to avoid zip issues
293293
],
294294
)
295295
def test_random_moments(self, r, alpha, time_covariate_vector):
@@ -298,13 +298,8 @@ def test_random_moments(self, r, alpha, time_covariate_vector):
298298
)
299299
draws = dist.eval()
300300

301-
# Check that all values are positive integers
302301
assert np.all(draws > 0)
303302
assert np.all(draws.astype(int) == draws)
304-
305-
# Check that values are reasonably distributed
306-
# Note: Exact moments are complex for this distribution
307-
# so we just check basic properties
308303
assert np.mean(draws) > 0
309304
assert np.var(draws) > 0
310305

@@ -337,21 +332,18 @@ def test_sampling_consistency(self):
337332
"""Test that sampling from the distribution produces reasonable results"""
338333
r = 2.0
339334
alpha = 1.0
340-
time_covariate_vector = 0.0 # Changed from None to 0.0 to avoid issues
335+
time_covariate_vector = [0.0, 1.0, 2.0]
341336

342-
# First test direct sampling from the distribution
343337
try:
344338
dist = GrassiaIIGeometric.dist(
345339
r=r, alpha=alpha, time_covariate_vector=time_covariate_vector
346340
)
347341

348342
direct_samples = dist.eval()
349343

350-
# Convert to numpy array if it's not already
351344
if not isinstance(direct_samples, np.ndarray):
352345
direct_samples = np.array([direct_samples])
353346

354-
# Ensure we have a 1D array
355347
if direct_samples.ndim == 0:
356348
direct_samples = direct_samples.reshape(1)
357349

@@ -371,7 +363,6 @@ def test_sampling_consistency(self):
371363
traceback.print_exc()
372364
raise
373365

374-
# Then test MCMC sampling
375366
try:
376367
with pm.Model():
377368
x = GrassiaIIGeometric(
@@ -382,7 +373,7 @@ def test_sampling_consistency(self):
382373
chains=1, draws=50, tune=0, random_seed=42, progressbar=False
383374
).posterior
384375

385-
# Extract samples and ensure they're in the correct shape
376+
# Extract samples and ensure correct shape
386377
samples = trace["x"].values
387378

388379
assert (
@@ -415,9 +406,7 @@ def test_sampling_consistency(self):
415406
), f"Variance {var} is not in valid range for {time_covariate_vector}"
416407

417408
# Additional checks for distribution properties
418-
# The mean should be greater than 1 for these parameters
419409
assert mean > 1, f"Mean {mean} is not greater than 1 for {time_covariate_vector}"
420-
# The variance should be positive and finite
421410
assert var > 0, f"Variance {var} is not positive for {time_covariate_vector}"
422411

423412
except Exception as e:

0 commit comments

Comments
 (0)