@@ -572,10 +572,14 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
572
572
573
573
def C_t (t : pt .TensorVariable , time_covariate_vector : pt .TensorVariable ) -> pt .TensorVariable :
574
574
"""Utility for processing time-varying covariates in GrassiaIIGeometric distribution."""
575
- # Ensure t is a valid index
576
- t_idx = pt .maximum (0 , t - 1 ) # Convert to 0-based indexing
577
- # If t_idx exceeds length of time_covariate_vector, use last value
578
- max_idx = pt .shape (time_covariate_vector )[0 ] - 1
579
- safe_idx = pt .minimum (t_idx , max_idx )
580
- covariate_value = time_covariate_vector [..., safe_idx ]
581
- return pt .exp (covariate_value ).sum (axis = 0 )
575
+ if time_covariate_vector .ndim == 0 :
576
+ # Reshape time_covariate_vector to length t
577
+ return pt .full ((t ,), time_covariate_vector )
578
+ else :
579
+ # Ensure t is a valid index
580
+ t_idx = pt .maximum (0 , t - 1 ) # Convert to 0-based indexing
581
+ # If t_idx exceeds length of time_covariate_vector, use last value
582
+ max_idx = pt .shape (time_covariate_vector )[0 ] - 1
583
+ safe_idx = pt .minimum (t_idx , max_idx )
584
+ covariate_value = time_covariate_vector [..., safe_idx ]
585
+ return pt .exp (covariate_value ).sum ()
0 commit comments