Skip to content

Commit b957333

Browse files
committed
WIP symbolic indexing
1 parent fa9c1ec commit b957333

File tree

3 files changed

+270
-151
lines changed

3 files changed

+270
-151
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 91 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -401,70 +401,63 @@ def dist(cls, mu1, mu2, **kwargs):
401401
**kwargs,
402402
)
403403

404-
# TODO: C expressions are not correct. Both value and covariate broadcasting must be handled.
404+
405405
class GrassiaIIGeometricRV(RandomVariable):
406406
name = "g2g"
407407
signature = "(),(),()->()"
408408

409409
dtype = "int64"
410410
_print_name = ("GrassiaIIGeometric", "\\operatorname{GrassiaIIGeometric}")
411411

412-
def __call__(self, r, alpha, time_covariates_sum=None, size=None, **kwargs):
413-
return super().__call__(r, alpha, time_covariates_sum, size=size, **kwargs)
412+
def __call__(self, r, alpha, time_covariate_vector=None, size=None, **kwargs):
    """Forward the distribution parameters to the parent op's ``__call__``.

    ``r``, ``alpha`` and ``time_covariate_vector`` are passed positionally, in
    that order; ``size`` and any extra keyword arguments are passed by keyword.
    ``time_covariate_vector`` defaults to ``None`` and is forwarded unchanged.
    """
    return super().__call__(r, alpha, time_covariate_vector, size=size, **kwargs)
414414

415415
@classmethod
416-
def rng_fn(cls, rng, r, alpha, time_covariates_sum, size):
417-
if time_covariates_sum is None:
418-
time_covariates_sum = np.array(0)
416+
def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
    """Draw samples as a gamma mixture of geometric distributions.

    A rate ``lam`` is drawn from ``Gamma(shape=r, scale=1 / alpha)``, scaled by
    ``exp(time_covariate_vector)``, converted to a per-period success
    probability ``p = 1 - exp(-lam)``, and a geometric variate is drawn with
    that probability.

    Parameters
    ----------
    rng
        NumPy random generator; must provide ``gamma`` and ``geometric``.
    r : array_like of float
        Shape of the gamma mixing distribution.
    alpha : array_like of float
        Rate of the gamma mixing distribution (used as ``scale = 1 / alpha``).
    time_covariate_vector : array_like of float or None
        Log-scale multiplier applied to the drawn rate. ``None`` is treated
        as ``0.0`` (no covariate effect).
    size : tuple of int or None
        Output shape. When ``None``, the broadcast shape of the three
        parameters is used.

    Returns
    -------
    numpy.ndarray
        ``int64`` samples, each ``>= 1``, with shape ``size``.
    """
    if time_covariate_vector is None:
        time_covariate_vector = 0.0

    # Inputs are expected to be concrete values here, not symbolic tensors.
    r = np.asarray(r, dtype=np.float64)
    alpha = np.asarray(alpha, dtype=np.float64)
    time_covariate_vector = np.asarray(time_covariate_vector, dtype=np.float64)

    if size is None:
        size = np.broadcast_shapes(r.shape, alpha.shape, time_covariate_vector.shape)

    r = np.broadcast_to(r, size)
    alpha = np.broadcast_to(alpha, size)
    time_covariate_vector = np.broadcast_to(time_covariate_vector, size)

    # With a constant covariate this mixture reproduces the survival function
    # (alpha / (alpha + t * exp(covariate))) ** r exactly.
    # NOTE(review): each element of time_covariate_vector scales the rate of
    # its own draw; covariates are not accumulated across time periods here.
    # Confirm this matches the intended time-varying model.
    lam = rng.gamma(shape=r, scale=1 / alpha, size=size)
    lam_covar = lam * np.exp(time_covariate_vector)

    # -expm1(-x) evaluates 1 - exp(-x) without cancellation for small x, so no
    # special-case threshold is needed.
    p = -np.expm1(-lam_covar)

    # Keep p strictly positive: a rate that underflowed to zero would
    # otherwise hand an invalid probability to the geometric sampler.
    p = np.clip(p, np.finfo(float).tiny, 1.0)

    # The op declares dtype "int64"; make the returned array match it.
    return np.asarray(rng.geometric(p), dtype=np.int64)
463455

464456

465457
g2g = GrassiaIIGeometricRV()
466458

467-
# TODO: C expressions are not correct. Both value and covariate broadcasting must be handled.
459+
460+
class GrassiaIIGeometric(Discrete):
468461
r"""Grassia(II)-Geometric distribution.
469462
470463
This distribution is a flexible alternative to the Geometric distribution for the number of trials until a
@@ -507,8 +500,8 @@ def rng_fn(cls, rng, r, alpha, time_covariates_sum, size):
507500
Shape parameter (r > 0).
508501
alpha : tensor_like of float
509502
Scale parameter (alpha > 0).
510-
time_covariates_sum : tensor_like of float, optional
511-
Optional dot product of time-varying covariates and their coefficients, summed over time.
503+
time_covariate_vector : tensor_like of float, optional
504+
Optional vector holding, for each time period, the dot product of the time-varying covariates and their coefficients.
512505
513506
References
514507
----------
@@ -520,34 +513,36 @@ def rng_fn(cls, rng, r, alpha, time_covariates_sum, size):
520513
rv_op = g2g
521514

522515
@classmethod
523-
def dist(cls, r, alpha, time_covariates_sum=None, *args, **kwargs):
516+
def dist(cls, r, alpha, time_covariate_vector=None, *args, **kwargs):
    """Create the distribution from ``r``, ``alpha`` and an optional covariate term.

    A missing ``time_covariate_vector`` is replaced by the constant ``0.0``.
    All three parameters are converted to tensor variables and handed to the
    parent ``dist`` as a single list, followed by any extra arguments.
    """
    if time_covariate_vector is None:
        time_covariate_vector = pt.constant(0.0)

    params = [pt.as_tensor_variable(param) for param in (r, alpha, time_covariate_vector)]
    return super().dist(params, *args, **kwargs)
523+
524+
def logp(value, r, alpha, time_covariate_vector=None):
525+
if time_covariate_vector is None:
526+
time_covariate_vector = pt.constant(0.0)
527+
time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)
528+
529+
def C_t(t):
    """Return the cumulative covariate exposure C(t) over the first ``t`` periods.

    C(t) is the sum of ``exp(time_covariate_vector[i])`` over the first ``t``
    entries, with C(0) = 0 so that the survival function equals 1 at t = 0.
    ``t`` may be symbolic, so nothing here branches on its value in Python.
    """
    # Scalar covariate: every period contributes the same amount.
    if time_covariate_vector.ndim == 0:
        return t * pt.exp(time_covariate_vector)

    # Running totals along the time axis, with a leading zero so that index t
    # holds the sum of the first t periods.
    running = pt.cumsum(pt.exp(time_covariate_vector), axis=-1)
    padded = pt.concatenate([pt.zeros_like(running[..., :1]), running], axis=-1)

    # Clamp to the available periods: negative t maps to C(0) = 0, and t past
    # the end of the vector reuses the final total.
    # NOTE(review): assumes the last axis of time_covariate_vector indexes time
    # periods - confirm for batched inputs.
    n_periods = time_covariate_vector.shape[-1]
    idx = pt.clip(pt.cast(t, "int64"), 0, n_periods)
    return pt.take(padded, idx, axis=-1)
546542

547543
# Calculate the PMF on log scale
548544
logp = pt.log(
549-
pt.pow(alpha / (alpha + C_tm1), r) -
550-
pt.pow(alpha / (alpha + C_t), r)
545+
pt.pow(alpha / (alpha + C_t(value - 1)), r) - pt.pow(alpha / (alpha + C_t(value)), r)
551546
)
552547

553548
# Handle invalid values
@@ -557,7 +552,7 @@ def logp(value, r, alpha, time_covariates_sum=None):
557552
pt.isnan(logp), # Handle NaN cases
558553
),
559554
-np.inf,
560-
logp
555+
logp,
561556
)
562557

563558
return check_parameters(
@@ -567,36 +562,52 @@ def logp(value, r, alpha, time_covariates_sum=None):
567562
msg="r > 0, alpha > 0",
568563
)
569564

570-
def logcdf(value, r, alpha, time_covariates_sum=None):
571-
if time_covariates_sum is not None:
572-
value = time_covariates_sum
573-
logcdf = r * (pt.log(value) - pt.log(alpha + value))
565+
def logcdf(value, r, alpha, time_covariate_vector=None):
    """Log of the cumulative distribution function.

    The CDF is ``1 - S(t)`` with survival function
    ``S(t) = (alpha / (alpha + C(t))) ** r``, where C(t) is the sum of
    ``exp(time_covariate_vector[i])`` over the first ``t`` periods.

    Parameters
    ----------
    value : tensor_like of int
        Point(s) at which to evaluate the log-CDF.
    r : tensor_like of float
        Shape parameter (r > 0).
    alpha : tensor_like of float
        Scale parameter (alpha > 0).
    time_covariate_vector : tensor_like of float, optional
        Per-period covariate term; ``None`` is treated as ``0.0``.

    Returns
    -------
    TensorVariable
        The log-CDF wrapped in a parameter check for ``r > 0, alpha > 0``.
    """
    if time_covariate_vector is None:
        time_covariate_vector = pt.constant(0.0)
    time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)

    def C_t(t):
        # Scalar covariate: every period contributes the same amount.
        if time_covariate_vector.ndim == 0:
            return t * pt.exp(time_covariate_vector)

        # Running totals with a leading zero so that index t holds the sum of
        # the first t periods and C(0) = 0.
        running = pt.cumsum(pt.exp(time_covariate_vector), axis=-1)
        padded = pt.concatenate([pt.zeros_like(running[..., :1]), running], axis=-1)

        # Clamp t to the available periods instead of branching on a
        # symbolic value.
        # NOTE(review): assumes the last axis indexes time periods - confirm.
        n_periods = time_covariate_vector.shape[-1]
        idx = pt.clip(pt.cast(t, "int64"), 0, n_periods)
        return pt.take(padded, idx, axis=-1)

    survival = pt.pow(alpha / (alpha + C_t(value)), r)

    # log1p(-S) keeps precision when the survival probability is tiny.
    log_cdf = pt.log1p(-survival)

    # No probability mass lies below the first period.
    log_cdf = pt.switch(pt.lt(value, 1), -np.inf, log_cdf)

    return check_parameters(
        log_cdf,
        r > 0,
        alpha > 0,
        msg="r > 0, alpha > 0",
    )
581592

582-
def support_point(rv, size, r, alpha, time_covariates_sum=None):
593+
def support_point(rv, size, r, alpha, time_covariate_vector=None):
583594
"""Calculate a reasonable starting point for sampling.
584595
585596
For the GrassiaIIGeometric distribution, we use a point estimate based on
586597
the expected value of the mixing distribution. Since the mixing distribution
587598
is Gamma(r, 1/alpha), its mean is r/alpha. We then transform this through
588599
the geometric link function and round to ensure an integer value.
589600
590-
When time_covariates_sum is provided, it affects the expected value through
591-
the exponential link function: exp(time_covariates_sum).
601+
When time_covariate_vector is provided, it affects the expected value through
602+
the exponential link function: exp(time_covariate_vector).
592603
"""
593604
# Base mean without covariates
594-
mean = pt.exp(alpha/r)
605+
mean = pt.exp(alpha / r)
595606

596607
# Apply time-varying covariates if provided
597-
if time_covariates_sum is None:
598-
time_covariates_sum = pt.constant(0.0)
599-
mean = mean * pt.exp(time_covariates_sum)
608+
if time_covariate_vector is None:
609+
time_covariate_vector = pt.constant(0.0)
610+
mean = mean * pt.exp(time_covariate_vector)
600611

601612
# Round up to nearest integer
602613
mean = pt.ceil(mean)

test_simple.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/usr/bin/env python3
2+
3+
import pymc as pm
4+
import pytensor.tensor as pt
5+
6+
from pymc_extras.distributions import GrassiaIIGeometric
7+
8+
9+
def test_basic_functionality():
    """Smoke-test construction, sampling and logp of GrassiaIIGeometric.

    Failures surface as exceptions or assertion errors, so both pytest and a
    direct script run report them. The function returns nothing, which is
    what pytest expects of a test function.
    """
    import math

    print("Testing basic GrassiaIIGeometric functionality...")

    # Tests 1 and 2: build the distribution without and with a scalar
    # covariate, then draw one sample from each.
    for label, covariate in (("None", None), ("scalar", 0.5)):
        dist = GrassiaIIGeometric.dist(r=2.0, alpha=1.0, time_covariate_vector=covariate)
        print(f"✓ Distribution created successfully with {label} time_covariate_vector")

        samples = dist.eval()
        assert (samples >= 1).all(), f"sample outside support: {samples}"
        print(f"✓ Direct sampling successful: {samples}")

    # Test 3: compile logp on symbolic inputs and evaluate it at a known point.
    r = pt.scalar("r")
    alpha = pt.scalar("alpha")
    time_covariate_vector = pt.scalar("time_covariate_vector")
    value = pt.scalar("value", dtype="int64")

    logp = pm.logp(GrassiaIIGeometric.dist(r, alpha, time_covariate_vector), value)
    logp_fn = pm.compile_fn([value, r, alpha, time_covariate_vector], logp)

    result = logp_fn(2, 1.0, 1.0, 0.0)
    # With r = alpha = 1 and no covariate effect: P(T=2) = 1/2 - 1/3 = 1/6.
    assert math.isclose(float(result), math.log(1 / 6), rel_tol=1e-6), result
    print(f"✓ Logp function works: {result}")

    print("✓ All basic functionality tests passed!")
58+
59+
60+
if __name__ == "__main__":
61+
test_basic_functionality()

0 commit comments

Comments
 (0)