Skip to content

Commit b34e3d8

Browse files
committed
make C_t external function, code cleanup
1 parent eb7222f commit b34e3d8

File tree

2 files changed

+20
-39
lines changed

2 files changed

+20
-39
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -510,24 +510,9 @@ def dist(cls, r, alpha, time_covariate_vector=None, *args, **kwargs):
510510
return super().dist([r, alpha, time_covariate_vector], *args, **kwargs)
511511

512512
def logp(value, r, alpha, time_covariate_vector):
513-
if time_covariate_vector is None:
514-
time_covariate_vector = pt.constant(0.0)
515-
time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)
516-
517-
def C_t(t):
518-
if time_covariate_vector.ndim == 0:
519-
return t
520-
else:
521-
# Ensure t is a valid index
522-
t_idx = pt.maximum(0, t - 1) # Convert to 0-based indexing
523-
# If t_idx exceeds length of time_covariate_vector, use last value
524-
max_idx = pt.shape(time_covariate_vector)[0] - 1
525-
safe_idx = pt.minimum(t_idx, max_idx)
526-
covariate_value = time_covariate_vector[safe_idx]
527-
return t * pt.exp(covariate_value)
528-
529513
logp = pt.log(
530-
pt.pow(alpha / (alpha + C_t(value - 1)), r) - pt.pow(alpha / (alpha + C_t(value)), r)
514+
pt.pow(alpha / (alpha + C_t(value - 1, time_covariate_vector)), r)
515+
- pt.pow(alpha / (alpha + C_t(value, time_covariate_vector)), r)
531516
)
532517

533518
# Handle invalid values
@@ -548,24 +533,10 @@ def C_t(t):
548533
)
549534

550535
def logcdf(value, r, alpha, time_covariate_vector):
551-
if time_covariate_vector is None:
552-
time_covariate_vector = pt.constant(0.0)
553-
time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)
554-
555-
def C_t(t):
556-
if time_covariate_vector.ndim == 0:
557-
return t
558-
else:
559-
# Ensure t is a valid index
560-
t_idx = pt.maximum(0, t - 1) # Convert to 0-based indexing
561-
# If t_idx exceeds length of time_covariate_vector, use last value
562-
max_idx = pt.shape(time_covariate_vector)[0] - 1
563-
safe_idx = pt.minimum(t_idx, max_idx)
564-
covariate_value = time_covariate_vector[safe_idx]
565-
return t * pt.exp(covariate_value)
566-
567-
survival = pt.pow(alpha / (alpha + C_t(value)), r)
568-
logcdf = pt.log(1 - survival)
536+
logcdf = r * (
537+
pt.log(C_t(value, time_covariate_vector))
538+
- pt.log(alpha + C_t(value, time_covariate_vector))
539+
)
569540

570541
return check_parameters(
571542
logcdf,
@@ -585,8 +556,6 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
585556
When time_covariate_vector is provided, it affects the expected value through
586557
the exponential link function: exp(time_covariate_vector).
587558
"""
588-
if time_covariate_vector is None:
589-
time_covariate_vector = pt.constant(0.0)
590559

591560
base_lambda = r / alpha
592561

@@ -608,3 +577,17 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
608577
mean = pt.full(size, mean)
609578

610579
return mean
580+
581+
582+
def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.TensorVariable:
583+
"""Utility for processing time-varying covariates in GrassiaIIGeometric distribution."""
584+
if time_covariate_vector.ndim == 0:
585+
return t
586+
else:
587+
# Ensure t is a valid index
588+
t_idx = pt.maximum(0, t - 1) # Convert to 0-based indexing
589+
# If t_idx exceeds length of time_covariate_vector, use last value
590+
max_idx = pt.shape(time_covariate_vector)[0] - 1
591+
safe_idx = pt.minimum(t_idx, max_idx)
592+
covariate_value = time_covariate_vector[safe_idx]
593+
return t * pt.exp(covariate_value)

tests/distributions/test_discrete.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727
Rplus,
2828
assert_support_point_is_expected,
2929
check_logp,
30-
check_logcdf,
31-
check_support_point,
3230
discrete_random_tester,
3331
)
3432
from pytensor import config

0 commit comments

Comments
 (0)