Skip to content

Commit 78be107

Browse files
committed
WIP add covar support to RV
1 parent bcd9cac commit 78be107

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -402,46 +402,48 @@ def dist(cls, mu1, mu2, **kwargs):
402402

403403
class GrassiaIIGeometricRV(RandomVariable):
404404
name = "g2g"
405-
signature = "(),()->()"
405+
signature = "(),(),()->()"
406406

407407
dtype = "int64"
408408
_print_name = ("GrassiaIIGeometric", "\\operatorname{GrassiaIIGeometric}")
409409

410-
def __call__(self, r, alpha, size=None, **kwargs):
411-
return super().__call__(r, alpha, size=size, **kwargs)
410+
def __call__(self, r, alpha, time_covar_dot=None,size=None, **kwargs):
411+
return super().__call__(r, alpha, time_covar_dot=time_covar_dot, size=size, **kwargs)
412412

413-
# TODO: Param will need to be added for dot product of time-varying covariates
414413
@classmethod
415-
def rng_fn(cls, rng, r, alpha, size):
414+
def rng_fn(cls, rng, r, alpha, time_covar_dot,size):
415+
if time_covar_dot is None:
416+
time_covar_dot = np.array(0)
416417
if size is None:
417-
size = np.broadcast_shapes(r.shape, alpha.shape)
418-
419-
r = np.asarray(r)
420-
alpha = np.asarray(alpha)
418+
size = np.broadcast_shapes(r.shape, alpha.shape, time_covar_dot.shape)
421419

422420
r = np.broadcast_to(r, size)
423421
alpha = np.broadcast_to(alpha, size)
422+
time_covar_dot = np.broadcast_to(time_covar_dot,size)
424423

425424
output = np.zeros(shape=size + (1,)) # noqa:RUF005
426425

427426
lam = rng.gamma(shape=r, scale=1/alpha, size=size)
428427

429-
def sim_data(lam):
428+
exp_time_covar_dot = np.exp(time_covar_dot)
429+
430+
def sim_data(lam, exp_time_covar_dot):
430431
# Handle numerical stability for very small lambda values
431-
p = np.where(
432-
lam < 0.001,
433-
lam, # For small lambda, p ≈ lambda
434-
1 - np.exp(-lam) # TODO: covariate param added here as 1 - np.exp(-lam * np.expcovar_dot)
435-
)
432+
# p = np.where(
433+
# lam < 0.0001,
434+
# lam, # For small lambda, p ≈ lambda
435+
# 1 - np.exp(-lam * exp_time_covar_dot)
436+
# )
436437

437-
# Ensure p is in valid range for geometric distribution
438-
p = np.clip(p, 1e-5, 1.)
438+
# Ensure lam is in valid range for geometric distribution
439+
lam = np.clip(lam, np.finfo(float).tiny, 1.)
440+
p = 1 - np.exp(-lam * exp_time_covar_dot)
439441

440442
t = rng.geometric(p)
441443
return np.array([t])
442444

443445
for index in np.ndindex(*size):
444-
output[index] = sim_data(lam[index])
446+
output[index] = sim_data(lam[index], exp_time_covar_dot[index])
445447

446448
return output
447449

0 commit comments

Comments
 (0)