@@ -433,6 +433,11 @@ def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
433
433
lam_covar = lam * exp_time_covar
434
434
435
435
p = 1 - np .exp (- lam_covar )
436
+ # TODO: This is a hack to ensure valid probability in (0, 1]
437
+ # We should find a better way to do this.
438
+ # Ensure valid probability in (0, 1]
439
+ tiny = np .finfo (p .dtype ).tiny
440
+ p = np .clip (p , tiny , 1.0 )
436
441
samples = rng .geometric (p )
437
442
# samples = np.ceil(np.log(1 - rng.uniform(size=size)) / (-lam_covar))
438
443
@@ -576,12 +581,11 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
576
581
1.0 / (1.0 - pt .exp (- base_lambda )), # Full expression for larger lambda
577
582
)
578
583
579
- # Apply time covariates if provided
584
+ # Apply time covariates if provided: multiply by exp(sum over axis=0)
585
+ # This yields a scalar for 1D covariates and a time-length vector for 2D (features x time)
580
586
tcv = pt .as_tensor_variable (time_covariate_vector )
581
587
if tcv .ndim != 0 :
582
- # If 1D, treat as per-time vector; if 2D+, sum features while preserving time axis
583
- cov_time = tcv if tcv .ndim == 1 else tcv .sum (axis = 0 )
584
- mean = mean * pt .exp (cov_time )
588
+ mean = mean * pt .exp (tcv .sum (axis = 0 ))
585
589
586
590
# Round up to nearest integer and ensure >= 1
587
591
mean = pt .maximum (pt .ceil (mean ), 1.0 )
@@ -603,8 +607,8 @@ def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.Te
603
607
if time_covariate_vector .ndim == 1 :
604
608
per_time_sum = pt .exp (time_covariate_vector )
605
609
else :
606
- feature_axes = tuple ( range ( time_covariate_vector . ndim - 1 ) )
607
- per_time_sum = pt .sum (pt .exp (time_covariate_vector ), axis = feature_axes )
610
+ # If axis=0 is time and axis>0 are features, sum over features (axis>0 )
611
+ per_time_sum = pt .sum (pt .exp (time_covariate_vector ), axis = 0 )
608
612
609
613
# Build cumulative sum up to each t without advanced indexing
610
614
time_length = pt .shape (per_time_sum )[0 ]
@@ -617,9 +621,5 @@ def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.Te
617
621
mask = pt .lt (time_idx , pt .shape_padright (t_vec , 1 ))
618
622
# Sum per-time contributions over time axis
619
623
base_sum = pt .sum (pt .shape_padleft (per_time_sum ) * mask , axis = - 1 )
620
- # Carry-forward last per-time value for t beyond time_length
621
- last_value = per_time_sum [- 1 ]
622
- excess_steps = pt .maximum (t_vec - time_length , 0 )
623
- carried = base_sum + excess_steps * last_value
624
- # If original t was scalar, return scalar
625
- return pt .squeeze (carried )
624
+ # If original t was scalar, return scalar (saturate at last time step)
625
+ return pt .squeeze (base_sum )
0 commit comments