Skip to content

Commit 434e5a5

Browse files
committed
C_t for RV time covar support
1 parent 0d1dcea commit 434e5a5

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -572,10 +572,14 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
572572

573573
def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.TensorVariable:
574574
"""Utility for processing time-varying covariates in GrassiaIIGeometric distribution."""
575-
# Ensure t is a valid index
576-
t_idx = pt.maximum(0, t - 1) # Convert to 0-based indexing
577-
# If t_idx exceeds length of time_covariate_vector, use last value
578-
max_idx = pt.shape(time_covariate_vector)[0] - 1
579-
safe_idx = pt.minimum(t_idx, max_idx)
580-
covariate_value = time_covariate_vector[..., safe_idx]
581-
return pt.exp(covariate_value).sum(axis=0)
575+
if time_covariate_vector.ndim == 0:
576+
# Reshape time_covariate_vector to length t
577+
return pt.full((t,), time_covariate_vector)
578+
else:
579+
# Ensure t is a valid index
580+
t_idx = pt.maximum(0, t - 1) # Convert to 0-based indexing
581+
# If t_idx exceeds length of time_covariate_vector, use last value
582+
max_idx = pt.shape(time_covariate_vector)[0] - 1
583+
safe_idx = pt.minimum(t_idx, max_idx)
584+
covariate_value = time_covariate_vector[..., safe_idx]
585+
return pt.exp(covariate_value).sum()

0 commit comments

Comments
 (0)