Skip to content

Commit 63a0b10

Browse files
committed
inverse cdf
1 parent 5ff6853 commit 63a0b10

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -425,10 +425,12 @@ def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
425425
# Calculate exp(time_covariate_vector) for all samples
426426
exp_time_covar = np.exp(
427427
time_covariate_vector
428-
).mean() # must average over time for correct broadcasting
428+
).mean() # Approximation required to return a t-scalar from a covariate vector
429429
lam_covar = lam * exp_time_covar
430430

431-
samples = np.ceil(rng.exponential(size=size) / lam_covar)
431+
# Take uniform draws from the inverse CDF
432+
u = rng.uniform(size=size)
433+
samples = np.ceil(np.log(1 - u) / (-lam_covar))
432434

433435
return samples
434436

@@ -581,5 +583,5 @@ def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.Te
581583
# If t_idx exceeds length of time_covariate_vector, use last value
582584
max_idx = pt.shape(time_covariate_vector)[0] - 1
583585
safe_idx = pt.minimum(t_idx, max_idx)
584-
covariate_value = time_covariate_vector[safe_idx]
585-
return t * pt.exp(covariate_value)
586+
covariate_value = time_covariate_vector[..., safe_idx]
587+
return pt.exp(covariate_value).sum()

tests/distributions/test_discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def test_random_edge_cases(self):
249249
# Test with small r and large alpha values
250250
r_vals = [0.1, 0.5]
251251
alpha_vals = [5.0, 10.0]
252-
time_cov_vals = [0.0, 1.0]
252+
time_cov_vals = [[0.0], [1.0]]
253253

254254
for r in r_vals:
255255
for alpha in alpha_vals:

0 commit comments

Comments
 (0)