@@ -409,7 +409,7 @@ def dist(cls, mu1, mu2, **kwargs):
409
409
class GrassiaIIGeometricRV (RandomVariable ):
410
410
name = "g2g"
411
411
signature = "(),(),()->()"
412
- ndims_params = [0 , 0 , 0 ] # r, alpha, time_covariate_vector are all scalars
412
+ ndims_params = [0 , 0 , 0 ] # deprecated in PyTensor 2.31.7, but still required for RandomVariable
413
413
414
414
dtype = "int64"
415
415
_print_name = ("GrassiaIIGeometric" , "\\ operatorname{GrassiaIIGeometric}" )
@@ -430,14 +430,13 @@ def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
430
430
alpha = np .broadcast_to (alpha , size )
431
431
time_covariate_vector = np .broadcast_to (time_covariate_vector , size )
432
432
433
- # Calculate exp(time_covariate_vector) for all samples
434
- exp_time_covar = np .exp (time_covariate_vector )
435
-
436
- # Generate gamma samples and apply time covariates
437
433
lam = rng .gamma (shape = r , scale = 1 / alpha , size = size )
438
434
439
- # TODO: Add C(t) to the calculation of lam_covar
435
+ # Calculate exp(time_covariate_vector) for all samples
436
+ exp_time_covar = np .exp (time_covariate_vector )
440
437
lam_covar = lam * exp_time_covar
438
+
439
+ # TODO: This is not aggregated over time
441
440
p = 1 - np .exp (- lam_covar )
442
441
443
442
# Ensure p is in valid range for geometric distribution
@@ -526,16 +525,18 @@ def logp(value, r, alpha, time_covariate_vector=None):
526
525
time_covariate_vector = pt .as_tensor_variable (time_covariate_vector )
527
526
528
527
def C_t (t ):
529
- # Aggregate time_covariate_vector over active time periods
530
528
if t == 0 :
531
529
return pt .constant (0.0 )
532
- # Handle case where time_covariate_vector is a scalar
533
530
if time_covariate_vector .ndim == 0 :
534
- return t * pt . exp ( time_covariate_vector )
531
+ return t
535
532
else :
536
- # For time covariates, this approximation avoids symbolic indexing issues
537
- mean_covariate = pt .mean (time_covariate_vector )
538
- return t * pt .exp (mean_covariate )
533
+ # Ensure t is a valid index
534
+ t_idx = pt .maximum (0 , t - 1 ) # Convert to 0-based indexing
535
+ # If t_idx exceeds length of time_covariate_vector, use last value
536
+ max_idx = pt .shape (time_covariate_vector )[0 ] - 1
537
+ safe_idx = pt .minimum (t_idx , max_idx )
538
+ covariate_value = time_covariate_vector [safe_idx ]
539
+ return t * pt .exp (covariate_value )
539
540
540
541
logp = pt .log (
541
542
pt .pow (alpha / (alpha + C_t (value - 1 )), r ) - pt .pow (alpha / (alpha + C_t (value )), r )
@@ -567,11 +568,15 @@ def C_t(t):
567
568
if t == 0 :
568
569
return pt .constant (0.0 )
569
570
if time_covariate_vector .ndim == 0 :
570
- return t * pt . exp ( time_covariate_vector )
571
+ return t
571
572
else :
572
- # For time covariates, this approximation avoids symbolic indexing issues
573
- mean_covariate = pt .mean (time_covariate_vector )
574
- return t * pt .exp (mean_covariate )
573
+ # Ensure t is a valid index
574
+ t_idx = pt .maximum (0 , t - 1 ) # Convert to 0-based indexing
575
+ # If t_idx exceeds length of time_covariate_vector, use last value
576
+ max_idx = pt .shape (time_covariate_vector )[0 ] - 1
577
+ safe_idx = pt .minimum (t_idx , max_idx )
578
+ covariate_value = time_covariate_vector [safe_idx ]
579
+ return t * pt .exp (covariate_value )
575
580
576
581
survival = pt .pow (alpha / (alpha + C_t (value )), r )
577
582
logcdf = pt .log (1 - survival )
0 commit comments