Skip to content

Commit 0fa3390

Browse files
committed
clean up cursor code
1 parent 05e7c55 commit 0fa3390

File tree

1 file changed

+12
-35
lines changed

1 file changed

+12
-35
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 12 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -410,15 +410,14 @@ class GrassiaIIGeometricRV(RandomVariable):
410410
_print_name = ("GrassiaIIGeometric", "\\operatorname{GrassiaIIGeometric}")
411411

412412
def __call__(self, r, alpha, time_covariate_vector=None, size=None, **kwargs):
413-
return super().__call__(r, alpha, time_covariate_vector, size=size, **kwargs)
413+
return super().__call__(r, alpha, time_covariate_vector, size, **kwargs)
414414

415415
@classmethod
416416
def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
417-
# Handle None case for time_covariate_vector
418417
if time_covariate_vector is None:
419418
time_covariate_vector = 0.0
420419

421-
# Convert inputs to numpy arrays - these should be concrete values
420+
# Cast inputs as numpy arrays
422421
r = np.asarray(r, dtype=np.float64)
423422
alpha = np.asarray(alpha, dtype=np.float64)
424423
time_covariate_vector = np.asarray(time_covariate_vector, dtype=np.float64)
@@ -427,33 +426,28 @@ def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
427426
if size is None:
428427
size = np.broadcast_shapes(r.shape, alpha.shape, time_covariate_vector.shape)
429428

430-
# Broadcast parameters to the output size
429+
# Broadcast parameters to output size
431430
r = np.broadcast_to(r, size)
432431
alpha = np.broadcast_to(alpha, size)
433432
time_covariate_vector = np.broadcast_to(time_covariate_vector, size)
434433

435434
# Calculate exp(time_covariate_vector) for all samples
436-
exp_time_covar_sum = np.exp(time_covariate_vector)
435+
exp_time_covar = np.exp(time_covariate_vector)
437436

438437
# Generate gamma samples and apply time covariates
439438
lam = rng.gamma(shape=r, scale=1 / alpha, size=size)
440-
lam_covar = lam * exp_time_covar_sum
441439

442-
# Calculate probability parameter for geometric distribution
443-
# Use the mathematically correct approach: 1 - exp(-lambda)
444-
# This matches the first test case and is theoretically sound
440+
# TODO: Add C(t) to the calculation of lam_covar
441+
lam_covar = lam * exp_time_covar
445442
p = 1 - np.exp(-lam_covar)
446443

447444
# Ensure p is in valid range for geometric distribution
448-
# Use a more conservative lower bound to prevent extremely large values
449445
min_p = max(1e-6, np.finfo(float).tiny) # Minimum probability to prevent infinite values
450446
p = np.clip(p, min_p, 1.0)
451447

452-
# Generate geometric samples
453448
samples = rng.geometric(p)
454449

455450
# Clip samples to reasonable bounds to prevent infinite values
456-
# Geometric distribution with small p can produce very large values
457451
max_sample = 10000 # Reasonable upper bound for discrete time-to-event data
458452
samples = np.clip(samples, 1, max_sample)
459453

@@ -507,7 +501,7 @@ class GrassiaIIGeometric(Discrete):
507501
alpha : tensor_like of float
508502
Scale parameter (alpha > 0).
509503
time_covariate_vector : tensor_like of float, optional
510-
Optional vector of dot product of time-varying covariates and their coefficients by time period.
504+
Optional vector containing dot products of time-varying covariates and coefficients.
511505
512506
References
513507
----------
@@ -535,19 +529,15 @@ def logp(value, r, alpha, time_covariate_vector=None):
535529
def C_t(t):
536530
# Aggregate time_covariate_vector over active time periods
537531
if t == 0:
538-
return pt.constant(1.0)
532+
return pt.constant(0.0)
539533
# Handle case where time_covariate_vector is a scalar
540534
if time_covariate_vector.ndim == 0:
541535
return t * pt.exp(time_covariate_vector)
542536
else:
543-
# For vector time_covariate_vector, use a simpler approach
544-
# that works with PyTensor's symbolic system
545-
# We'll use the mean of the time covariates multiplied by t
546-
# This is an approximation but avoids symbolic indexing issues
537+
# For time covariates, this approximation avoids symbolic indexing issues
547538
mean_covariate = pt.mean(time_covariate_vector)
548539
return t * pt.exp(mean_covariate)
549540

550-
# Calculate the PMF on log scale
551541
logp = pt.log(
552542
pt.pow(alpha / (alpha + C_t(value - 1)), r) - pt.pow(alpha / (alpha + C_t(value)), r)
553543
)
@@ -574,21 +564,13 @@ def logcdf(value, r, alpha, time_covariate_vector=None):
574564
time_covariate_vector = pt.constant(0.0)
575565
time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)
576566

577-
# Calculate CDF on log scale
578-
# For the GrassiaIIGeometric, the CDF is 1 - survival function
579-
# S(t) = (alpha/(alpha + C(t)))^r
580-
# CDF(t) = 1 - S(t)
581-
582567
def C_t(t):
583568
if t == 0:
584569
return pt.constant(1.0)
585570
if time_covariate_vector.ndim == 0:
586571
return t * pt.exp(time_covariate_vector)
587572
else:
588-
# For vector time_covariate_vector, use a simpler approach
589-
# that works with PyTensor's symbolic system
590-
# We'll use the mean of the time covariates multiplied by t
591-
# This is an approximation but avoids symbolic indexing issues
573+
# For time covariates, this approximation avoids symbolic indexing issues
592574
mean_covariate = pt.mean(time_covariate_vector)
593575
return t * pt.exp(mean_covariate)
594576

@@ -613,14 +595,9 @@ def support_point(rv, size, r, alpha, time_covariate_vector=None):
613595
When time_covariate_vector is provided, it affects the expected value through
614596
the exponential link function: exp(time_covariate_vector).
615597
"""
616-
# Base mean from the gamma mixing distribution: E[lambda] = r/alpha
617-
# For a geometric distribution with parameter p, E[X] = 1/p
618-
# Since p = 1 - exp(-lambda), we approximate E[X] ≈ 1/(1 - exp(-E[lambda]))
619598
base_lambda = r / alpha
620599

621-
# Approximate the expected value of the geometric distribution
622-
# For small lambda, 1 - exp(-lambda) ≈ lambda, so E[X] ≈ 1/lambda
623-
# For larger lambda, we use the full expression
600+
# Approximate expected value of geometric distribution
624601
mean = pt.switch(
625602
base_lambda < 0.1,
626603
1.0 / base_lambda, # Approximation for small lambda
@@ -631,7 +608,7 @@ def support_point(rv, size, r, alpha, time_covariate_vector=None):
631608
if time_covariate_vector is not None:
632609
mean = mean * pt.exp(time_covariate_vector)
633610

634-
# Round up to nearest integer and ensure it's at least 1
611+
# Round up to nearest integer and ensure >= 1
635612
mean = pt.maximum(pt.ceil(mean), 1.0)
636613

637614
# Handle size parameter

0 commit comments

Comments
 (0)