Skip to content

Commit 0b05f28

Browse files
committed
WIP fix C_t broadcasting
1 parent c66c8a6 commit 0b05f28

File tree

2 files changed

+123
-30
lines changed

2 files changed

+123
-30
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 65 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,12 @@ class GrassiaIIGeometricRV(RandomVariable):
412412
@classmethod
413413
def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
414414
# Aggregate time covariates for each sample before broadcasting
415-
exp_time_covar = np.exp(time_covariate_vector).sum(axis=0)
415+
time_cov = np.asarray(time_covariate_vector)
416+
if np.ndim(time_cov) == 0:
417+
exp_time_covar = np.asarray(1.0)
418+
else:
419+
# Collapse all time/feature axes to a scalar multiplier for RNG
420+
exp_time_covar = np.asarray(np.exp(time_cov).sum())
416421

417422
# Determine output size
418423
if size is None:
@@ -500,24 +505,29 @@ def dist(cls, r, alpha, time_covariate_vector=None, *args, **kwargs):
500505

501506
if time_covariate_vector is None:
502507
time_covariate_vector = pt.constant(0.0)
508+
time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)
509+
# Normalize covariate to be 1D over time
510+
if time_covariate_vector.ndim == 0:
511+
time_covariate_vector = pt.reshape(time_covariate_vector, (1,))
512+
elif time_covariate_vector.ndim > 1:
513+
feature_axes = tuple(range(time_covariate_vector.ndim - 1))
514+
time_covariate_vector = pt.sum(time_covariate_vector, axis=feature_axes)
503515

504516
return super().dist([r, alpha, time_covariate_vector], *args, **kwargs)
505517

506518
def logp(value, r, alpha, time_covariate_vector):
507-
logp = pt.log(
508-
pt.pow(alpha / (alpha + C_t(value - 1, time_covariate_vector)), r)
509-
- pt.pow(alpha / (alpha + C_t(value, time_covariate_vector)), r)
510-
)
511-
512-
# Handle invalid values
513-
logp = pt.switch(
514-
pt.or_(
515-
value < 1, # Value must be >= 1
516-
pt.isnan(logp), # Handle NaN cases
517-
),
518-
-np.inf,
519-
logp,
520-
)
519+
v = pt.as_tensor_variable(value)
520+
ct_prev = C_t(v - 1, time_covariate_vector)
521+
ct_curr = C_t(v, time_covariate_vector)
522+
logS_prev = r * (pt.log(alpha) - pt.log(alpha + ct_prev))
523+
logS_curr = r * (pt.log(alpha) - pt.log(alpha + ct_curr))
524+
# Compute log(exp(logS_prev) - exp(logS_curr)) stably
525+
max_logS = pt.maximum(logS_prev, logS_curr)
526+
diff = pt.exp(logS_prev - max_logS) - pt.exp(logS_curr - max_logS)
527+
logp = max_logS + pt.log(diff)
528+
529+
# Handle invalid / out-of-domain values
530+
logp = pt.switch(value < 1, -np.inf, logp)
521531

522532
return check_parameters(
523533
logp,
@@ -527,9 +537,15 @@ def logp(value, r, alpha, time_covariate_vector):
527537
)
528538

529539
def logcdf(value, r, alpha, time_covariate_vector):
530-
logcdf = r * (
531-
pt.log(C_t(value, time_covariate_vector))
532-
- pt.log(alpha + C_t(value, time_covariate_vector))
540+
# Log CDF: log(1 - (alpha / (alpha + C(t)))**r)
541+
t = pt.as_tensor_variable(value)
542+
ct = C_t(t, time_covariate_vector)
543+
logS = r * (pt.log(alpha) - pt.log(alpha + ct))
544+
# Numerically stable log(1 - exp(logS))
545+
logcdf = pt.switch(
546+
pt.lt(logS, np.log(0.5)),
547+
pt.log1p(-pt.exp(logS)),
548+
pt.log(-pt.expm1(logS)),
533549
)
534550

535551
return check_parameters(
@@ -561,7 +577,11 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
561577
)
562578

563579
# Apply time covariates if provided
564-
mean = mean * pt.exp(time_covariate_vector.sum(axis=0))
580+
tcv = pt.as_tensor_variable(time_covariate_vector)
581+
if tcv.ndim != 0:
582+
# If 1D, treat as per-time vector; if 2D+, sum features while preserving time axis
583+
cov_time = tcv if tcv.ndim == 1 else tcv.sum(axis=0)
584+
mean = mean * pt.exp(cov_time)
565585

566586
# Round up to nearest integer and ensure >= 1
567587
mean = pt.maximum(pt.ceil(mean), 1.0)
@@ -575,14 +595,31 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
575595

576596
def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.TensorVariable:
577597
"""Utility for processing time-varying covariates in GrassiaIIGeometric distribution."""
598+
# If unspecified (scalar), simply return t
578599
if time_covariate_vector.ndim == 0:
579-
# Reshape time_covariate_vector to length t
580-
return pt.full((t,), time_covariate_vector)
600+
return t
601+
602+
# Sum exp(covariates) across feature axes, keep last axis as time
603+
if time_covariate_vector.ndim == 1:
604+
per_time_sum = pt.exp(time_covariate_vector)
581605
else:
582-
# Ensure t is a valid index
583-
t_idx = pt.maximum(0, t - 1) # Convert to 0-based indexing
584-
# If t_idx exceeds length of time_covariate_vector, use last value
585-
max_idx = pt.shape(time_covariate_vector)[0] - 1
586-
safe_idx = pt.minimum(t_idx, max_idx)
587-
covariate_value = time_covariate_vector[..., safe_idx]
588-
return pt.exp(covariate_value).sum()
606+
feature_axes = tuple(range(time_covariate_vector.ndim - 1))
607+
per_time_sum = pt.sum(pt.exp(time_covariate_vector), axis=feature_axes)
608+
609+
# Build cumulative sum up to each t without advanced indexing
610+
time_length = pt.shape(per_time_sum)[0]
611+
# Ensure t is at least 1D int64 for broadcasting
612+
t_vec = pt.cast(t, "int64")
613+
t_vec = pt.shape_padleft(t_vec) if t_vec.ndim == 0 else t_vec
614+
# Create time indices [0, 1, ..., T-1]
615+
time_idx = pt.arange(time_length, dtype="int64")
616+
# Mask where time index < t (exclusive upper bound)
617+
mask = pt.lt(time_idx, pt.shape_padright(t_vec, 1))
618+
# Sum per-time contributions over time axis
619+
base_sum = pt.sum(pt.shape_padleft(per_time_sum) * mask, axis=-1)
620+
# Carry-forward last per-time value for t beyond time_length
621+
last_value = per_time_sum[-1]
622+
excess_steps = pt.maximum(t_vec - time_length, 0)
623+
carried = base_sum + excess_steps * last_value
624+
# If original t was scalar, return scalar
625+
return pt.squeeze(carried)

tests/distributions/test_discrete.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
BaseTestDistributionRandom,
2525
Domain,
2626
I,
27-
NatBig,
27+
PosNat,
2828
Rplus,
2929
assert_support_point_is_expected,
3030
check_logp,
@@ -314,7 +314,7 @@ def test_logp(self):
314314
def test_logcdf(self):
315315
# test logcdf matches log sums across parameter values
316316
check_selfconsistency_discrete_logcdf(
317-
GrassiaIIGeometric, NatBig, {"r": Rplus, "alpha": Rplus, "time_covariate_vector": I}
317+
GrassiaIIGeometric, PosNat, {"r": Rplus, "alpha": Rplus, "time_covariate_vector": I}
318318
)
319319

320320
@pytest.mark.parametrize(
@@ -349,3 +349,59 @@ def test_support_point(self, r, alpha, time_covariate_vector, size, expected_sha
349349

350350
# TODO: expected values must be provided
351351
# assert_support_point_is_expected(model, init_point)
352+
353+
def test_C_t_unspecified_returns_t(self):
354+
# When unspecified is represented as a scalar 0.0, C_t should return t
355+
from pymc_extras.distributions.discrete import C_t
356+
357+
t = pt.vector("t", dtype="int64")
358+
cov = pt.as_tensor_variable(0.0)
359+
fn = pytensor.function([t], C_t(t, cov))
360+
test_t = np.array([0, 1, 2, 3, 10], dtype="int64")
361+
np.testing.assert_array_equal(fn(test_t), test_t)
362+
363+
def test_C_t_1d_vector_sum_up_to_t_with_saturation(self):
364+
# For a 1D time_covariate_vector, C_t should sum exp up to t (exclusive upper bound),
365+
# and saturate when t exceeds the length
366+
from pymc_extras.distributions.discrete import C_t
367+
368+
t = pt.vector("t", dtype="int64")
369+
cov = pt.as_tensor_variable(np.array([0.0, 1.0, -1.0], dtype=float)) # length 3
370+
fn = pytensor.function([t], C_t(t, cov))
371+
test_t = np.array([0, 1, 2, 3, 4], dtype="int64")
372+
per_time = np.exp(np.array([0.0, 1.0, -1.0]))
373+
csum = np.cumsum(per_time)
374+
expected = []
375+
for tt in test_t:
376+
if tt <= 0:
377+
expected.append(0.0)
378+
elif tt >= len(per_time):
379+
expected.append(csum[-1])
380+
else:
381+
expected.append(csum[tt - 1])
382+
expected = np.array(expected)
383+
np.testing.assert_allclose(fn(test_t), expected)
384+
385+
def test_C_t_2d_features_by_time_sum_up_to_t_with_saturation(self):
386+
# For a 2D (features x time) covariate, sum features first then cumulative over time
387+
# and saturate when t exceeds the length
388+
from pymc_extras.distributions.discrete import C_t
389+
390+
t = pt.vector("t", dtype="int64")
391+
cov = pt.as_tensor_variable(
392+
np.array([[0.5, 1.0, 1.5], [0.0, 0.0, 0.0]], dtype=float)
393+
) # 2x3
394+
fn = pytensor.function([t], C_t(t, cov))
395+
test_t = np.array([0, 1, 2, 3, 4], dtype="int64")
396+
per_time = np.sum(np.exp(np.array([[0.5, 1.0, 1.5], [0.0, 0.0, 0.0]])), axis=0)
397+
csum = np.cumsum(per_time)
398+
expected = []
399+
for tt in test_t:
400+
if tt <= 0:
401+
expected.append(0.0)
402+
elif tt >= len(per_time):
403+
expected.append(csum[-1])
404+
else:
405+
expected.append(csum[tt - 1])
406+
expected = np.array(expected)
407+
np.testing.assert_allclose(fn(test_t), expected)

0 commit comments

Comments
 (0)