@@ -510,24 +510,9 @@ def dist(cls, r, alpha, time_covariate_vector=None, *args, **kwargs):
510
510
return super ().dist ([r , alpha , time_covariate_vector ], * args , ** kwargs )
511
511
512
512
def logp (value , r , alpha , time_covariate_vector ):
513
- if time_covariate_vector is None :
514
- time_covariate_vector = pt .constant (0.0 )
515
- time_covariate_vector = pt .as_tensor_variable (time_covariate_vector )
516
-
517
- def C_t (t ):
518
- if time_covariate_vector .ndim == 0 :
519
- return t
520
- else :
521
- # Ensure t is a valid index
522
- t_idx = pt .maximum (0 , t - 1 ) # Convert to 0-based indexing
523
- # If t_idx exceeds length of time_covariate_vector, use last value
524
- max_idx = pt .shape (time_covariate_vector )[0 ] - 1
525
- safe_idx = pt .minimum (t_idx , max_idx )
526
- covariate_value = time_covariate_vector [safe_idx ]
527
- return t * pt .exp (covariate_value )
528
-
529
513
logp = pt .log (
530
- pt .pow (alpha / (alpha + C_t (value - 1 )), r ) - pt .pow (alpha / (alpha + C_t (value )), r )
514
+ pt .pow (alpha / (alpha + C_t (value - 1 , time_covariate_vector )), r )
515
+ - pt .pow (alpha / (alpha + C_t (value , time_covariate_vector )), r )
531
516
)
532
517
533
518
# Handle invalid values
@@ -548,24 +533,10 @@ def C_t(t):
548
533
)
549
534
550
535
def logcdf (value , r , alpha , time_covariate_vector ):
551
- if time_covariate_vector is None :
552
- time_covariate_vector = pt .constant (0.0 )
553
- time_covariate_vector = pt .as_tensor_variable (time_covariate_vector )
554
-
555
- def C_t (t ):
556
- if time_covariate_vector .ndim == 0 :
557
- return t
558
- else :
559
- # Ensure t is a valid index
560
- t_idx = pt .maximum (0 , t - 1 ) # Convert to 0-based indexing
561
- # If t_idx exceeds length of time_covariate_vector, use last value
562
- max_idx = pt .shape (time_covariate_vector )[0 ] - 1
563
- safe_idx = pt .minimum (t_idx , max_idx )
564
- covariate_value = time_covariate_vector [safe_idx ]
565
- return t * pt .exp (covariate_value )
566
-
567
- survival = pt .pow (alpha / (alpha + C_t (value )), r )
568
- logcdf = pt .log (1 - survival )
536
+ logcdf = r * (
537
+ pt .log (C_t (value , time_covariate_vector ))
538
+ - pt .log (alpha + C_t (value , time_covariate_vector ))
539
+ )
569
540
570
541
return check_parameters (
571
542
logcdf ,
@@ -585,8 +556,6 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
585
556
When time_covariate_vector is provided, it affects the expected value through
586
557
the exponential link function: exp(time_covariate_vector).
587
558
"""
588
- if time_covariate_vector is None :
589
- time_covariate_vector = pt .constant (0.0 )
590
559
591
560
base_lambda = r / alpha
592
561
@@ -608,3 +577,17 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
608
577
mean = pt .full (size , mean )
609
578
610
579
return mean
580
+
581
+
582
+ def C_t (t : pt .TensorVariable , time_covariate_vector : pt .TensorVariable ) -> pt .TensorVariable :
583
+ """Utility for processing time-varying covariates in GrassiaIIGeometric distribution."""
584
+ if time_covariate_vector .ndim == 0 :
585
+ return t
586
+ else :
587
+ # Ensure t is a valid index
588
+ t_idx = pt .maximum (0 , t - 1 ) # Convert to 0-based indexing
589
+ # If t_idx exceeds length of time_covariate_vector, use last value
590
+ max_idx = pt .shape (time_covariate_vector )[0 ] - 1
591
+ safe_idx = pt .minimum (t_idx , max_idx )
592
+ covariate_value = time_covariate_vector [safe_idx ]
593
+ return t * pt .exp (covariate_value )
0 commit comments