@@ -412,7 +412,12 @@ class GrassiaIIGeometricRV(RandomVariable):
412
412
@classmethod
413
413
def rng_fn (cls , rng , r , alpha , time_covariate_vector , size ):
414
414
# Aggregate time covariates for each sample before broadcasting
415
- exp_time_covar = np .exp (time_covariate_vector ).sum (axis = 0 )
415
+ time_cov = np .asarray (time_covariate_vector )
416
+ if np .ndim (time_cov ) == 0 :
417
+ exp_time_covar = np .asarray (1.0 )
418
+ else :
419
+ # Collapse all time/feature axes to a scalar multiplier for RNG
420
+ exp_time_covar = np .asarray (np .exp (time_cov ).sum ())
416
421
417
422
# Determine output size
418
423
if size is None :
@@ -500,24 +505,29 @@ def dist(cls, r, alpha, time_covariate_vector=None, *args, **kwargs):
500
505
501
506
if time_covariate_vector is None :
502
507
time_covariate_vector = pt .constant (0.0 )
508
+ time_covariate_vector = pt .as_tensor_variable (time_covariate_vector )
509
+ # Normalize covariate to be 1D over time
510
+ if time_covariate_vector .ndim == 0 :
511
+ time_covariate_vector = pt .reshape (time_covariate_vector , (1 ,))
512
+ elif time_covariate_vector .ndim > 1 :
513
+ feature_axes = tuple (range (time_covariate_vector .ndim - 1 ))
514
+ time_covariate_vector = pt .sum (time_covariate_vector , axis = feature_axes )
503
515
504
516
return super ().dist ([r , alpha , time_covariate_vector ], * args , ** kwargs )
505
517
506
518
def logp (value , r , alpha , time_covariate_vector ):
507
- logp = pt .log (
508
- pt .pow (alpha / (alpha + C_t (value - 1 , time_covariate_vector )), r )
509
- - pt .pow (alpha / (alpha + C_t (value , time_covariate_vector )), r )
510
- )
511
-
512
- # Handle invalid values
513
- logp = pt .switch (
514
- pt .or_ (
515
- value < 1 , # Value must be >= 1
516
- pt .isnan (logp ), # Handle NaN cases
517
- ),
518
- - np .inf ,
519
- logp ,
520
- )
519
+ v = pt .as_tensor_variable (value )
520
+ ct_prev = C_t (v - 1 , time_covariate_vector )
521
+ ct_curr = C_t (v , time_covariate_vector )
522
+ logS_prev = r * (pt .log (alpha ) - pt .log (alpha + ct_prev ))
523
+ logS_curr = r * (pt .log (alpha ) - pt .log (alpha + ct_curr ))
524
+ # Compute log(exp(logS_prev) - exp(logS_curr)) stably
525
+ max_logS = pt .maximum (logS_prev , logS_curr )
526
+ diff = pt .exp (logS_prev - max_logS ) - pt .exp (logS_curr - max_logS )
527
+ logp = max_logS + pt .log (diff )
528
+
529
+ # Handle invalid / out-of-domain values
530
+ logp = pt .switch (value < 1 , - np .inf , logp )
521
531
522
532
return check_parameters (
523
533
logp ,
@@ -527,9 +537,15 @@ def logp(value, r, alpha, time_covariate_vector):
527
537
)
528
538
529
539
def logcdf (value , r , alpha , time_covariate_vector ):
530
- logcdf = r * (
531
- pt .log (C_t (value , time_covariate_vector ))
532
- - pt .log (alpha + C_t (value , time_covariate_vector ))
540
+ # Log CDF: log(1 - (alpha / (alpha + C(t)))**r)
541
+ t = pt .as_tensor_variable (value )
542
+ ct = C_t (t , time_covariate_vector )
543
+ logS = r * (pt .log (alpha ) - pt .log (alpha + ct ))
544
+ # Numerically stable log(1 - exp(logS))
545
+ logcdf = pt .switch (
546
+ pt .lt (logS , np .log (0.5 )),
547
+ pt .log1p (- pt .exp (logS )),
548
+ pt .log (- pt .expm1 (logS )),
533
549
)
534
550
535
551
return check_parameters (
@@ -561,7 +577,11 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
561
577
)
562
578
563
579
# Apply time covariates if provided
564
- mean = mean * pt .exp (time_covariate_vector .sum (axis = 0 ))
580
+ tcv = pt .as_tensor_variable (time_covariate_vector )
581
+ if tcv .ndim != 0 :
582
+ # If 1D, treat as per-time vector; if 2D+, sum features while preserving time axis
583
+ cov_time = tcv if tcv .ndim == 1 else tcv .sum (axis = 0 )
584
+ mean = mean * pt .exp (cov_time )
565
585
566
586
# Round up to nearest integer and ensure >= 1
567
587
mean = pt .maximum (pt .ceil (mean ), 1.0 )
@@ -575,14 +595,31 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
575
595
576
596
def C_t (t : pt .TensorVariable , time_covariate_vector : pt .TensorVariable ) -> pt .TensorVariable :
577
597
"""Utility for processing time-varying covariates in GrassiaIIGeometric distribution."""
598
+ # If unspecified (scalar), simply return t
578
599
if time_covariate_vector .ndim == 0 :
579
- # Reshape time_covariate_vector to length t
580
- return pt .full ((t ,), time_covariate_vector )
600
+ return t
601
+
602
+ # Sum exp(covariates) across feature axes, keep last axis as time
603
+ if time_covariate_vector .ndim == 1 :
604
+ per_time_sum = pt .exp (time_covariate_vector )
581
605
else :
582
- # Ensure t is a valid index
583
- t_idx = pt .maximum (0 , t - 1 ) # Convert to 0-based indexing
584
- # If t_idx exceeds length of time_covariate_vector, use last value
585
- max_idx = pt .shape (time_covariate_vector )[0 ] - 1
586
- safe_idx = pt .minimum (t_idx , max_idx )
587
- covariate_value = time_covariate_vector [..., safe_idx ]
588
- return pt .exp (covariate_value ).sum ()
606
+ feature_axes = tuple (range (time_covariate_vector .ndim - 1 ))
607
+ per_time_sum = pt .sum (pt .exp (time_covariate_vector ), axis = feature_axes )
608
+
609
+ # Build cumulative sum up to each t without advanced indexing
610
+ time_length = pt .shape (per_time_sum )[0 ]
611
+ # Ensure t is at least 1D int64 for broadcasting
612
+ t_vec = pt .cast (t , "int64" )
613
+ t_vec = pt .shape_padleft (t_vec ) if t_vec .ndim == 0 else t_vec
614
+ # Create time indices [0, 1, ..., T-1]
615
+ time_idx = pt .arange (time_length , dtype = "int64" )
616
+ # Mask where time index < t (exclusive upper bound)
617
+ mask = pt .lt (time_idx , pt .shape_padright (t_vec , 1 ))
618
+ # Sum per-time contributions over time axis
619
+ base_sum = pt .sum (pt .shape_padleft (per_time_sum ) * mask , axis = - 1 )
620
+ # Carry-forward last per-time value for t beyond time_length
621
+ last_value = per_time_sum [- 1 ]
622
+ excess_steps = pt .maximum (t_vec - time_length , 0 )
623
+ carried = base_sum + excess_steps * last_value
624
+ # If original t was scalar, return scalar
625
+ return pt .squeeze (carried )
0 commit comments