Skip to content

Commit 932a046

Browse files
committed
covariate pos constraint and WIP RV
1 parent 63a0b10 commit 932a046

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def dist(cls, mu1, mu2, **kwargs):
404404

405405
class GrassiaIIGeometricRV(RandomVariable):
406406
name = "g2g"
407-
signature = "(),(),()->()"
407+
signature = "(),(),(t)->()"
408408

409409
dtype = "int64"
410410
_print_name = ("GrassiaIIGeometric", "\\operatorname{GrassiaIIGeometric}")
@@ -422,15 +422,13 @@ def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
422422

423423
lam = rng.gamma(shape=r, scale=1 / alpha, size=size)
424424

425-
# Calculate exp(time_covariate_vector) for all samples
425+
# Aggregate time covariates for each sample
426426
exp_time_covar = np.exp(
427-
time_covariate_vector
428-
).mean() # Approximation required to return a t-scalar from a covariate vector
427+
time_covariate_vector.sum(axis=0)
428+
) # TODO: try np.exp(time_covariate_vector).sum(axis=0) instead?
429429
lam_covar = lam * exp_time_covar
430430

431-
# Take uniform draws from the inverse CDF
432-
u = rng.uniform(size=size)
433-
samples = np.ceil(np.log(1 - u) / (-lam_covar))
431+
samples = np.ceil(np.log(1 - rng.uniform(size=size)) / (-lam_covar))
434432

435433
return samples
436434

@@ -536,6 +534,7 @@ def logcdf(value, r, alpha, time_covariate_vector):
536534
logcdf,
537535
r > 0,
538536
alpha > 0,
537+
time_covariate_vector >= 0,
539538
msg="r > 0, alpha > 0",
540539
)
541540

@@ -573,6 +572,7 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
573572
return mean
574573

575574

575+
# TODO: can this be moved into logp? Indexing not required for logcdf
576576
def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.TensorVariable:
577577
"""Utility for processing time-varying covariates in GrassiaIIGeometric distribution."""
578578
if time_covariate_vector.ndim == 0:

0 commit comments

Comments
 (0)