@@ -402,46 +402,48 @@ def dist(cls, mu1, mu2, **kwargs):
402
402
403
403
class GrassiaIIGeometricRV (RandomVariable ):
404
404
name = "g2g"
405
- signature = "(),()->()"
405
+ signature = "(),(),() ->()"
406
406
407
407
dtype = "int64"
408
408
_print_name = ("GrassiaIIGeometric" , "\\ operatorname{GrassiaIIGeometric}" )
409
409
410
- def __call__ (self , r , alpha , size = None , ** kwargs ):
411
- return super ().__call__ (r , alpha , size = size , ** kwargs )
410
+ def __call__ (self , r , alpha , time_covar_dot = None , size = None , ** kwargs ):
411
+ return super ().__call__ (r , alpha , time_covar_dot = time_covar_dot , size = size , ** kwargs )
412
412
413
- # TODO: Param will need to be added for dot product of time-varying covariates
414
413
@classmethod
415
- def rng_fn (cls , rng , r , alpha , size ):
414
+ def rng_fn (cls , rng , r , alpha , time_covar_dot ,size ):
415
+ if time_covar_dot is None :
416
+ time_covar_dot = np .array (0 )
416
417
if size is None :
417
- size = np .broadcast_shapes (r .shape , alpha .shape )
418
-
419
- r = np .asarray (r )
420
- alpha = np .asarray (alpha )
418
+ size = np .broadcast_shapes (r .shape , alpha .shape , time_covar_dot .shape )
421
419
422
420
r = np .broadcast_to (r , size )
423
421
alpha = np .broadcast_to (alpha , size )
422
+ time_covar_dot = np .broadcast_to (time_covar_dot ,size )
424
423
425
424
output = np .zeros (shape = size + (1 ,)) # noqa:RUF005
426
425
427
426
lam = rng .gamma (shape = r , scale = 1 / alpha , size = size )
428
427
429
- def sim_data (lam ):
428
+ exp_time_covar_dot = np .exp (time_covar_dot )
429
+
430
+ def sim_data (lam , exp_time_covar_dot ):
430
431
# Handle numerical stability for very small lambda values
431
- p = np .where (
432
- lam < 0.001 ,
433
- lam , # For small lambda, p ≈ lambda
434
- 1 - np . exp ( - lam ) # TODO: covariate param added here as 1 - np.exp(-lam * np.expcovar_dot )
435
- )
432
+ # p = np.where(
433
+ # lam < 0.0001 ,
434
+ # lam, # For small lambda, p ≈ lambda
435
+ # 1 - np.exp(-lam * exp_time_covar_dot )
436
+ # )
436
437
437
- # Ensure p is in valid range for geometric distribution
438
- p = np .clip (p , 1e-5 , 1. )
438
+ # Ensure lam is in valid range for geometric distribution
439
+ lam = np .clip (lam , np .finfo (float ).tiny , 1. )
440
+ p = 1 - np .exp (- lam * exp_time_covar_dot )
439
441
440
442
t = rng .geometric (p )
441
443
return np .array ([t ])
442
444
443
445
for index in np .ndindex (* size ):
444
- output [index ] = sim_data (lam [index ])
446
+ output [index ] = sim_data (lam [index ], exp_time_covar_dot [ index ] )
445
447
446
448
return output
447
449
0 commit comments