Skip to content

Commit 9803321

Browse files
committed
rng_fn cleanup
1 parent b34e3d8 commit 9803321

File tree

1 file changed

+5
-13
lines changed

1 file changed

+5
-13
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -423,28 +423,20 @@ def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
423423
lam = rng.gamma(shape=r, scale=1 / alpha, size=size)
424424

425425
# Calculate exp(time_covariate_vector) for all samples
426-
exp_time_covar = np.exp(time_covariate_vector)
426+
exp_time_covar = np.exp(
427+
time_covariate_vector
428+
).mean() # must average over time for correct broadcasting
427429
lam_covar = lam * exp_time_covar
428430

429-
# TODO: Derive inverse log_cdf and use rng.uniform or rng.log_uniform
430-
p = 1 - np.exp(-lam_covar)
431-
432-
# Ensure p is in valid range for geometric distribution
433-
min_p = max(1e-6, np.finfo(float).tiny) # Minimum probability to prevent infinite values
434-
p = np.clip(p, min_p, 1.0)
435-
436-
samples = rng.geometric(p)
437-
438-
# Clip samples to reasonable bounds to prevent infinite values
439-
max_sample = 10000 # Reasonable upper bound for discrete time-to-event data
440-
samples = np.clip(samples, 1, max_sample)
431+
samples = np.ceil(rng.exponential(size=size) / lam_covar)
441432

442433
return samples
443434

444435

445436
g2g = GrassiaIIGeometricRV()
446437

447438

439+
# TODO: Add covariate expressions to docstrings.
448440
class GrassiaIIGeometric(Discrete):
449441
r"""Grassia(II)-Geometric distribution.
450442

0 commit comments

Comments
 (0)