Skip to content

Commit 8a30459

Browse files
committed
WIP time indexing
1 parent 78be107 commit 8a30459

File tree

2 files changed

+178
-80
lines changed

2 files changed

+178
-80
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 104 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -399,59 +399,70 @@ def dist(cls, mu1, mu2, **kwargs):
399399
**kwargs,
400400
)
401401

402-
402+
# TODO: C expressions are not correct. Both value and covariate broadcasting must be handled.
403403
class GrassiaIIGeometricRV(RandomVariable):
404404
name = "g2g"
405405
signature = "(),(),()->()"
406406

407407
dtype = "int64"
408408
_print_name = ("GrassiaIIGeometric", "\\operatorname{GrassiaIIGeometric}")
409409

410-
def __call__(self, r, alpha, time_covar_dot=None,size=None, **kwargs):
411-
return super().__call__(r, alpha, time_covar_dot=time_covar_dot, size=size, **kwargs)
410+
def __call__(self, r, alpha, time_covariates_sum=None, size=None, **kwargs):
411+
return super().__call__(r, alpha, time_covariates_sum, size=size, **kwargs)
412412

413413
@classmethod
414-
def rng_fn(cls, rng, r, alpha, time_covar_dot,size):
415-
if time_covar_dot is None:
416-
time_covar_dot = np.array(0)
414+
def rng_fn(cls, rng, r, alpha, time_covariates_sum, size):
415+
if time_covariates_sum is None:
416+
time_covariates_sum = np.array(0)
417417
if size is None:
418-
size = np.broadcast_shapes(r.shape, alpha.shape, time_covar_dot.shape)
418+
size = np.broadcast_shapes(r.shape, alpha.shape, time_covariates_sum.shape)
419419

420420
r = np.broadcast_to(r, size)
421421
alpha = np.broadcast_to(alpha, size)
422-
time_covar_dot = np.broadcast_to(time_covar_dot,size)
423-
424-
output = np.zeros(shape=size + (1,)) # noqa:RUF005
425-
426-
lam = rng.gamma(shape=r, scale=1/alpha, size=size)
427-
428-
exp_time_covar_dot = np.exp(time_covar_dot)
429-
430-
def sim_data(lam, exp_time_covar_dot):
431-
# Handle numerical stability for very small lambda values
432-
# p = np.where(
433-
# lam < 0.0001,
434-
# lam, # For small lambda, p ≈ lambda
435-
# 1 - np.exp(-lam * exp_time_covar_dot)
436-
# )
437-
438-
# Ensure lam is in valid range for geometric distribution
439-
lam = np.clip(lam, np.finfo(float).tiny, 1.)
440-
p = 1 - np.exp(-lam * exp_time_covar_dot)
441-
442-
t = rng.geometric(p)
443-
return np.array([t])
444-
445-
for index in np.ndindex(*size):
446-
output[index] = sim_data(lam[index], exp_time_covar_dot[index])
447-
422+
time_covariates_sum = np.broadcast_to(time_covariates_sum, size)
423+
424+
# Calculate exp(time_covariates_sum) for all samples
425+
exp_time_covar_sum = np.exp(time_covariates_sum)
426+
427+
# Initialize output array
428+
output = np.zeros(size, dtype=np.int64)
429+
430+
# For each sample, generate a value from the distribution
431+
for idx in np.ndindex(*size):
432+
# Calculate survival probabilities for each possible value
433+
t = 1
434+
while True:
435+
C_t = t + exp_time_covar_sum[idx]
436+
C_tm1 = (t - 1) + exp_time_covar_sum[idx]
437+
438+
# Calculate PMF for current t
439+
pmf = (
440+
(alpha[idx] / (alpha[idx] + C_tm1)) ** r[idx] -
441+
(alpha[idx] / (alpha[idx] + C_t)) ** r[idx]
442+
)
443+
444+
# If PMF is negative or NaN, we've gone too far
445+
if pmf <= 0 or np.isnan(pmf):
446+
break
447+
448+
# Accept this value with probability proportional to PMF
449+
if rng.random() < pmf:
450+
output[idx] = t
451+
break
452+
453+
t += 1
454+
455+
# Safety check to prevent infinite loops
456+
if t > 1000: # Arbitrary large number
457+
output[idx] = t
458+
break
459+
448460
return output
449461

450462

451463
g2g = GrassiaIIGeometricRV()
452464

453-
# TODO: Add time-varying covariates. May simply replace the t-value , but is a continuous parameter
454-
class GrassiaIIGeometric(Discrete):
465+
# TODO: C expressions are not correct. Both value and covariate broadcasting must be handled.
455466
r"""Grassia(II)-Geometric distribution.
456467
457468
This distribution is a flexible alternative to the Geometric distribution for the number of trials until a
@@ -494,7 +505,9 @@ class GrassiaIIGeometric(Discrete):
494505
Shape parameter (r > 0).
495506
alpha : tensor_like of float
496507
Scale parameter (alpha > 0).
497-
508+
time_covariates_sum : tensor_like of float, optional
509+
Optional dot product of time-varying covariates and their coefficients, summed over time.
510+
498511
References
499512
----------
500513
.. [1] Fader, Peter & G. S. Hardie, Bruce (2020).
@@ -505,22 +518,56 @@ class GrassiaIIGeometric(Discrete):
505518
rv_op = g2g
506519

507520
@classmethod
508-
def dist(cls, r, alpha, *args, **kwargs):
521+
def dist(cls, r, alpha, time_covariates_sum=None, *args, **kwargs):
509522
r = pt.as_tensor_variable(r)
510523
alpha = pt.as_tensor_variable(alpha)
511-
return super().dist([r, alpha], *args, **kwargs)
512-
513-
def logp(value, r, alpha):
514-
logp = -r * (pt.log(alpha + value - 1) + pt.log(alpha + value))
524+
if time_covariates_sum is None:
525+
time_covariates_sum = pt.constant(0.0)
526+
time_covariates_sum = pt.as_tensor_variable(time_covariates_sum)
527+
return super().dist([r, alpha, time_covariates_sum], *args, **kwargs)
515528

529+
def logp(value, r, alpha, time_covariates_sum=None):
530+
"""
531+
Log probability function for GrassiaIIGeometric distribution.
532+
533+
The PMF is:
534+
P(T=t|r,α,β;Z(t)) = (α/(α+C(t-1)))^r - (α/(α+C(t)))^r
535+
536+
where C(t) = t + exp(time_covariates_sum)
537+
"""
538+
if time_covariates_sum is None:
539+
time_covariates_sum = pt.constant(0.0)
540+
541+
# Calculate C(t) and C(t-1)
542+
C_t = value + pt.exp(time_covariates_sum)
543+
C_tm1 = (value - 1) + pt.exp(time_covariates_sum)
544+
545+
# Calculate the PMF on log scale
546+
logp = pt.log(
547+
pt.pow(alpha / (alpha + C_tm1), r) -
548+
pt.pow(alpha / (alpha + C_t), r)
549+
)
550+
551+
# Handle invalid values
552+
logp = pt.switch(
553+
pt.or_(
554+
value < 1, # Value must be >= 1
555+
pt.isnan(logp), # Handle NaN cases
556+
),
557+
-np.inf,
558+
logp
559+
)
560+
516561
return check_parameters(
517562
logp,
518563
r > 0,
519564
alpha > 0,
520-
msg="s > 0, alpha > 0",
565+
msg="r > 0, alpha > 0",
521566
)
522567

523-
def logcdf(value, r, alpha):
568+
def logcdf(value, r, alpha, time_covariates_sum=None):
569+
if time_covariates_sum is not None:
570+
value = time_covariates_sum
524571
logcdf = r * (pt.log(value) - pt.log(alpha + value))
525572

526573
return check_parameters(
@@ -530,15 +577,27 @@ def logcdf(value, r, alpha):
530577
msg="r > 0, alpha > 0",
531578
)
532579

533-
def support_point(rv, size, r, alpha):
580+
def support_point(rv, size, r, alpha, time_covariates_sum=None):
534581
"""Calculate a reasonable starting point for sampling.
535582
536583
For the GrassiaIIGeometric distribution, we use a point estimate based on
537584
the expected value of the mixing distribution. Since the mixing distribution
538585
is Gamma(r, 1/alpha), its mean is r/alpha. We then transform this through
539586
the geometric link function and round to ensure an integer value.
587+
588+
When time_covariates_sum is provided, it affects the expected value through
589+
the exponential link function: exp(time_covariates_sum).
540590
"""
541-
mean = pt.ceil(pt.exp(alpha/r))
591+
# Base mean without covariates
592+
mean = pt.exp(alpha/r)
593+
594+
# Apply time-varying covariates if provided
595+
if time_covariates_sum is None:
596+
time_covariates_sum = pt.constant(0.0)
597+
mean = mean * pt.exp(time_covariates_sum)
598+
599+
# Round up to nearest integer
600+
mean = pt.ceil(mean)
542601

543602
if not rv_size_is_none(size):
544603
mean = pt.full(size, mean)

tests/distributions/test_discrete.py

Lines changed: 74 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -214,23 +214,24 @@ def test_logp(self):
214214
class TestGrassiaIIGeometric:
215215
class TestRandomVariable(BaseTestDistributionRandom):
216216
pymc_dist = GrassiaIIGeometric
217-
pymc_dist_params = {"r": .5, "alpha": 2.0}
218-
expected_rv_op_params = {"r": .5, "alpha": 2.0}
217+
pymc_dist_params = {"r": .5, "alpha": 2.0, "time_covariates_sum": 1.0}
218+
expected_rv_op_params = {"r": .5, "alpha": 2.0, "time_covariates_sum": 1.0}
219219
tests_to_run = [
220220
"check_pymc_params_match_rv_op",
221221
"check_rv_size",
222222
]
223223

224224
def test_random_basic_properties(self):
225-
# Test standard parameter values
225+
# Test standard parameter values with time covariates
226226
discrete_random_tester(
227227
dist=self.pymc_dist,
228228
paramdomains={
229229
"r": Domain([0.5, 1.0, 2.0], edges=(None, None)), # Standard values
230230
"alpha": Domain([0.5, 1.0, 2.0], edges=(None, None)), # Standard values
231+
"time_covariates_sum": Domain([-1.0, 1.0, 2.0], edges=(None, None)), # Time covariates
231232
},
232-
ref_rand=lambda r, alpha, size: np.random.geometric(
233-
1 - np.exp(-np.random.gamma(r, 1/alpha, size=size)), size=size
233+
ref_rand=lambda r, alpha, time_covariates_sum, size: np.random.geometric(
234+
1 - np.exp(-np.random.gamma(r, 1/alpha, size=size) * np.exp(time_covariates_sum)), size=size
234235
),
235236
)
236237

@@ -240,20 +241,21 @@ def test_random_basic_properties(self):
240241
paramdomains={
241242
"r": Domain([0.01, 0.1], edges=(None, None)), # Small r values
242243
"alpha": Domain([10.0, 100.0], edges=(None, None)), # Large alpha values
244+
"time_covariates_sum": Domain([0.0, 1.0], edges=(None, None)), # Time covariates
243245
},
244-
ref_rand=lambda r, alpha, size: np.random.geometric(
245-
np.clip(np.random.gamma(r, 1/alpha, size=size), 1e-5, 1.0), size=size
246+
ref_rand=lambda r, alpha, time_covariates_sum, size: np.random.geometric(
247+
np.clip(np.random.gamma(r, 1/alpha, size=size) * np.exp(time_covariates_sum), 1e-5, 1.0), size=size
246248
),
247249
)
248250

249-
@pytest.mark.parametrize("r,alpha", [
250-
(0.5, 1.0),
251-
(1.0, 2.0),
252-
(2.0, 0.5),
253-
(5.0, 1.0),
251+
@pytest.mark.parametrize("r,alpha,time_covariates_sum", [
252+
(0.5, 1.0, 0.0),
253+
(1.0, 2.0, 1.0),
254+
(2.0, 0.5, -1.0),
255+
(5.0, 1.0, None),
254256
])
255-
def test_random_moments(self, r, alpha):
256-
dist = self.pymc_dist.dist(r=r, alpha=alpha, size=10_000)
257+
def test_random_moments(self, r, alpha, time_covariates_sum):
258+
dist = self.pymc_dist.dist(r=r, alpha=alpha, time_covariates_sum=time_covariates_sum, size=10_000)
257259
draws = dist.eval()
258260

259261
# Check that all values are positive integers
@@ -269,65 +271,102 @@ def test_random_moments(self, r, alpha):
269271
def test_logp_basic(self):
270272
r = pt.scalar("r")
271273
alpha = pt.scalar("alpha")
274+
time_covariates_sum = pt.scalar("time_covariates_sum")
272275
value = pt.vector("value", dtype="int64")
273276

274-
logp = pm.logp(GrassiaIIGeometric.dist(r, alpha), value)
275-
logp_fn = pytensor.function([value, r, alpha], logp)
277+
logp = pm.logp(GrassiaIIGeometric.dist(r, alpha, time_covariates_sum), value)
278+
logp_fn = pytensor.function([value, r, alpha, time_covariates_sum], logp)
276279

277280
# Test basic properties of logp
278281
test_value = np.array([1, 2, 3, 4, 5])
279282
test_r = 1.0
280283
test_alpha = 1.0
284+
test_time_covariates_sum = 1.0
281285

282-
logp_vals = logp_fn(test_value, test_r, test_alpha)
286+
logp_vals = logp_fn(test_value, test_r, test_alpha, test_time_covariates_sum)
283287
assert not np.any(np.isnan(logp_vals))
284288
assert np.all(np.isfinite(logp_vals))
285289

286290
# Test invalid values
287-
assert logp_fn(np.array([0]), test_r, test_alpha) == np.inf # Value must be > 0
291+
assert logp_fn(np.array([0]), test_r, test_alpha, test_time_covariates_sum) == np.inf # Value must be > 0
288292

289293
with pytest.raises(TypeError):
290-
logp_fn(np.array([1.5]), test_r, test_alpha) == -np.inf # Value must be integer
294+
logp_fn(np.array([1.5]), test_r, test_alpha, test_time_covariates_sum) # Value must be integer
291295

292296
# Test parameter restrictions
293297
with pytest.raises(ParameterValueError):
294-
logp_fn(np.array([1]), -1.0, test_alpha) # r must be > 0
298+
logp_fn(np.array([1]), -1.0, test_alpha, test_time_covariates_sum) # r must be > 0
295299

296300
with pytest.raises(ParameterValueError):
297-
logp_fn(np.array([1]), test_r, -1.0) # alpha must be > 0
301+
logp_fn(np.array([1]), test_r, -1.0, test_time_covariates_sum) # alpha must be > 0
298302

299303
def test_sampling_consistency(self):
300304
"""Test that sampling from the distribution produces reasonable results"""
301305
r = 2.0
302306
alpha = 1.0
307+
time_covariates_sum = None
308+
309+
# First test direct sampling from the distribution
310+
dist = GrassiaIIGeometric.dist(r=r, alpha=alpha, time_covariates_sum=time_covariates_sum)
311+
direct_samples = dist.eval()
312+
313+
# Convert to numpy array if it's not already
314+
if not isinstance(direct_samples, np.ndarray):
315+
direct_samples = np.array([direct_samples])
316+
317+
# Ensure we have a 1D array
318+
if direct_samples.ndim == 0:
319+
direct_samples = direct_samples.reshape(1)
320+
321+
assert direct_samples.size > 0, "Direct sampling produced no samples"
322+
assert np.all(direct_samples > 0), "Direct sampling produced non-positive values"
323+
assert np.all(direct_samples.astype(int) == direct_samples), "Direct sampling produced non-integer values"
324+
325+
# Then test MCMC sampling
303326
with pm.Model():
304-
x = GrassiaIIGeometric("x", r=r, alpha=alpha)
327+
x = GrassiaIIGeometric("x", r=r, alpha=alpha, time_covariates_sum=time_covariates_sum)
305328
trace = pm.sample(chains=1, draws=1000, random_seed=42).posterior
306329

307-
samples = trace["x"].values.flatten()
330+
# Extract samples and ensure they're in the correct shape
331+
samples = trace["x"].values
332+
assert samples is not None, "No samples were returned from MCMC"
333+
assert samples.size > 0, "MCMC sampling produced empty array"
334+
335+
if samples.ndim > 1:
336+
samples = samples.reshape(-1) # Flatten if needed
308337

309338
# Check basic properties of samples
310-
assert np.all(samples > 0) # All values should be positive
311-
assert np.all(samples.astype(int) == samples) # All values should be integers
339+
assert samples.size > 0, "No samples after reshaping"
340+
assert np.all(samples > 0), "Found non-positive values in samples"
341+
assert np.all(samples.astype(int) == samples), "Found non-integer values in samples"
312342

313343
# Check mean and variance are reasonable
314-
# (exact values depend on the parameterization)
315-
assert 0 < np.mean(samples) < np.inf
316-
assert 0 < np.var(samples) < np.inf
344+
mean = np.mean(samples)
345+
var = np.var(samples)
346+
assert 0 < mean < np.inf, f"Mean {mean} is not in valid range"
347+
assert 0 < var < np.inf, f"Variance {var} is not in valid range"
348+
349+
# Additional checks for distribution properties
350+
# The mean should be greater than 1 for these parameters
351+
assert mean > 1, f"Mean {mean} is not greater than 1"
352+
# The variance should be positive and finite
353+
assert var > 0, f"Variance {var} is not positive"
317354

318355
@pytest.mark.parametrize(
319-
"r, alpha, size, expected_shape",
356+
"r, alpha, time_covariates_sum, size, expected_shape",
320357
[
321-
(1.0, 1.0, None, ()), # Scalar output
322-
([1.0, 2.0], 1.0, None, (2,)), # Vector output from r
323-
(1.0, [1.0, 2.0], None, (2,)), # Vector output from alpha
324-
(1.0, 1.0, (3, 2), (3, 2)), # Explicit size
358+
(1.0, 1.0, 1.0, None, ()), # Scalar output with covariates
359+
([1.0, 2.0], 1.0, 1.0, None, (2,)), # Vector output from r
360+
(1.0, [1.0, 2.0], 1.0, None, (2,)), # Vector output from alpha
361+
(1.0, 1.0, None, None, ()), # No time covariates
362+
(1.0, 1.0, [1.0, 2.0], None, (2,)), # Vector output from time covariates
363+
(1.0, 1.0, 1.0, (3, 2), (3, 2)), # Explicit size
325364
],
326365
)
327-
def test_support_point(self, r, alpha, size, expected_shape):
366+
def test_support_point(self, r, alpha, time_covariates_sum, size, expected_shape):
328367
"""Test that support_point returns reasonable values with correct shapes"""
329368
with pm.Model() as model:
330-
GrassiaIIGeometric("x", r=r, alpha=alpha, size=size)
369+
GrassiaIIGeometric("x", r=r, alpha=alpha, time_covariates_sum=time_covariates_sum, size=size)
331370

332371
init_point = model.initial_point()["x"]
333372

0 commit comments

Comments
 (0)