Skip to content

Commit e3731eb

Browse files
committed
WIP tests
1 parent d002ca8 commit e3731eb

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,11 @@ def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
433433
lam_covar = lam * exp_time_covar
434434

435435
p = 1 - np.exp(-lam_covar)
436+
# TODO: This is a hack to ensure valid probability in (0, 1]
437+
# We should find a better way to do this.
438+
# Ensure valid probability in (0, 1]
439+
tiny = np.finfo(p.dtype).tiny
440+
p = np.clip(p, tiny, 1.0)
436441
samples = rng.geometric(p)
437442
# samples = np.ceil(np.log(1 - rng.uniform(size=size)) / (-lam_covar))
438443

@@ -576,12 +581,11 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
576581
1.0 / (1.0 - pt.exp(-base_lambda)), # Full expression for larger lambda
577582
)
578583

579-
# Apply time covariates if provided
584+
# Apply time covariates if provided: multiply by exp(sum over axis=0)
585+
# This yields a scalar for 1D covariates and a time-length vector for 2D (features x time)
580586
tcv = pt.as_tensor_variable(time_covariate_vector)
581587
if tcv.ndim != 0:
582-
# If 1D, treat as per-time vector; if 2D+, sum features while preserving time axis
583-
cov_time = tcv if tcv.ndim == 1 else tcv.sum(axis=0)
584-
mean = mean * pt.exp(cov_time)
588+
mean = mean * pt.exp(tcv.sum(axis=0))
585589

586590
# Round up to nearest integer and ensure >= 1
587591
mean = pt.maximum(pt.ceil(mean), 1.0)
@@ -603,8 +607,8 @@ def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.Te
603607
if time_covariate_vector.ndim == 1:
604608
per_time_sum = pt.exp(time_covariate_vector)
605609
else:
606-
feature_axes = tuple(range(time_covariate_vector.ndim - 1))
607-
per_time_sum = pt.sum(pt.exp(time_covariate_vector), axis=feature_axes)
610+
# If axis=0 is time and axis>0 are features, sum over features (axis>0)
611+
per_time_sum = pt.sum(pt.exp(time_covariate_vector), axis=0)
608612

609613
# Build cumulative sum up to each t without advanced indexing
610614
time_length = pt.shape(per_time_sum)[0]
@@ -617,9 +621,5 @@ def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.Te
617621
mask = pt.lt(time_idx, pt.shape_padright(t_vec, 1))
618622
# Sum per-time contributions over time axis
619623
base_sum = pt.sum(pt.shape_padleft(per_time_sum) * mask, axis=-1)
620-
# Carry-forward last per-time value for t beyond time_length
621-
last_value = per_time_sum[-1]
622-
excess_steps = pt.maximum(t_vec - time_length, 0)
623-
carried = base_sum + excess_steps * last_value
624-
# If original t was scalar, return scalar
625-
return pt.squeeze(carried)
624+
# If original t was scalar, return scalar (saturate at last time step)
625+
return pt.squeeze(base_sum)

0 commit comments

Comments
 (0)