Skip to content

Commit 7c7afc8

Browse files
committed
WIP time indexing
1 parent 8a30459 commit 7c7afc8

File tree

2 files changed

+66
-40
lines changed

2 files changed

+66
-40
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 17 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -423,40 +423,40 @@ def rng_fn(cls, rng, r, alpha, time_covariates_sum, size):
423423

424424
# Calculate exp(time_covariates_sum) for all samples
425425
exp_time_covar_sum = np.exp(time_covariates_sum)
426-
426+
427427
# Initialize output array
428428
output = np.zeros(size, dtype=np.int64)
429-
429+
430430
# For each sample, generate a value from the distribution
431431
for idx in np.ndindex(*size):
432432
# Calculate survival probabilities for each possible value
433433
t = 1
434434
while True:
435435
C_t = t + exp_time_covar_sum[idx]
436436
C_tm1 = (t - 1) + exp_time_covar_sum[idx]
437-
437+
438438
# Calculate PMF for current t
439439
pmf = (
440-
(alpha[idx] / (alpha[idx] + C_tm1)) ** r[idx] -
440+
(alpha[idx] / (alpha[idx] + C_tm1)) ** r[idx] -
441441
(alpha[idx] / (alpha[idx] + C_t)) ** r[idx]
442442
)
443-
443+
444444
# If PMF is negative or NaN, we've gone too far
445445
if pmf <= 0 or np.isnan(pmf):
446446
break
447-
447+
448448
# Accept this value with probability proportional to PMF
449449
if rng.random() < pmf:
450450
output[idx] = t
451451
break
452-
452+
453453
t += 1
454-
454+
455455
# Safety check to prevent infinite loops
456456
if t > 1000: # Arbitrary large number
457457
output[idx] = t
458458
break
459-
459+
460460
return output
461461

462462

@@ -507,7 +507,7 @@ def rng_fn(cls, rng, r, alpha, time_covariates_sum, size):
507507
Scale parameter (alpha > 0).
508508
time_covariates_sum : tensor_like of float, optional
509509
Optional dot product of time-varying covariates and their coefficients, summed over time.
510-
510+
511511
References
512512
----------
513513
.. [1] Fader, Peter & G. S. Hardie, Bruce (2020).
@@ -529,25 +529,25 @@ def dist(cls, r, alpha, time_covariates_sum=None, *args, **kwargs):
529529
def logp(value, r, alpha, time_covariates_sum=None):
530530
"""
531531
Log probability function for GrassiaIIGeometric distribution.
532-
532+
533533
The PMF is:
534534
P(T=t|r,α,β;Z(t)) = (α/(α+C(t-1)))^r - (α/(α+C(t)))^r
535-
535+
536536
where C(t) = t + exp(time_covariates_sum)
537537
"""
538538
if time_covariates_sum is None:
539539
time_covariates_sum = pt.constant(0.0)
540-
540+
541541
# Calculate C(t) and C(t-1)
542542
C_t = value + pt.exp(time_covariates_sum)
543543
C_tm1 = (value - 1) + pt.exp(time_covariates_sum)
544-
544+
545545
# Calculate the PMF on log scale
546546
logp = pt.log(
547-
pt.pow(alpha / (alpha + C_tm1), r) -
547+
pt.pow(alpha / (alpha + C_tm1), r) -
548548
pt.pow(alpha / (alpha + C_t), r)
549549
)
550-
550+
551551
# Handle invalid values
552552
logp = pt.switch(
553553
pt.or_(
@@ -557,7 +557,7 @@ def logp(value, r, alpha, time_covariates_sum=None):
557557
-np.inf,
558558
logp
559559
)
560-
560+
561561
return check_parameters(
562562
logp,
563563
r > 0,

tests/distributions/test_discrete.py

Lines changed: 49 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -214,8 +214,8 @@ def test_logp(self):
214214
class TestGrassiaIIGeometric:
215215
class TestRandomVariable(BaseTestDistributionRandom):
216216
pymc_dist = GrassiaIIGeometric
217-
pymc_dist_params = {"r": .5, "alpha": 2.0, "time_covariates_sum": 1.0}
218-
expected_rv_op_params = {"r": .5, "alpha": 2.0, "time_covariates_sum": 1.0}
217+
pymc_dist_params = {"r": 0.5, "alpha": 2.0, "time_covariates_sum": 1.0}
218+
expected_rv_op_params = {"r": 0.5, "alpha": 2.0, "time_covariates_sum": 1.0}
219219
tests_to_run = [
220220
"check_pymc_params_match_rv_op",
221221
"check_rv_size",
@@ -228,10 +228,16 @@ def test_random_basic_properties(self):
228228
paramdomains={
229229
"r": Domain([0.5, 1.0, 2.0], edges=(None, None)), # Standard values
230230
"alpha": Domain([0.5, 1.0, 2.0], edges=(None, None)), # Standard values
231-
"time_covariates_sum": Domain([-1.0, 1.0, 2.0], edges=(None, None)), # Time covariates
231+
"time_covariates_sum": Domain(
232+
[-1.0, 1.0, 2.0], edges=(None, None)
233+
), # Time covariates
232234
},
233235
ref_rand=lambda r, alpha, time_covariates_sum, size: np.random.geometric(
234-
1 - np.exp(-np.random.gamma(r, 1/alpha, size=size) * np.exp(time_covariates_sum)), size=size
236+
1
237+
- np.exp(
238+
-np.random.gamma(r, 1 / alpha, size=size) * np.exp(time_covariates_sum)
239+
),
240+
size=size,
235241
),
236242
)
237243

@@ -241,21 +247,33 @@ def test_random_basic_properties(self):
241247
paramdomains={
242248
"r": Domain([0.01, 0.1], edges=(None, None)), # Small r values
243249
"alpha": Domain([10.0, 100.0], edges=(None, None)), # Large alpha values
244-
"time_covariates_sum": Domain([0.0, 1.0], edges=(None, None)), # Time covariates
250+
"time_covariates_sum": Domain(
251+
[0.0, 1.0], edges=(None, None)
252+
), # Time covariates
245253
},
246254
ref_rand=lambda r, alpha, time_covariates_sum, size: np.random.geometric(
247-
np.clip(np.random.gamma(r, 1/alpha, size=size) * np.exp(time_covariates_sum), 1e-5, 1.0), size=size
255+
np.clip(
256+
np.random.gamma(r, 1 / alpha, size=size) * np.exp(time_covariates_sum),
257+
1e-5,
258+
1.0,
259+
),
260+
size=size,
248261
),
249262
)
250263

251-
@pytest.mark.parametrize("r,alpha,time_covariates_sum", [
252-
(0.5, 1.0, 0.0),
253-
(1.0, 2.0, 1.0),
254-
(2.0, 0.5, -1.0),
255-
(5.0, 1.0, None),
256-
])
264+
@pytest.mark.parametrize(
265+
"r,alpha,time_covariates_sum",
266+
[
267+
(0.5, 1.0, 0.0),
268+
(1.0, 2.0, 1.0),
269+
(2.0, 0.5, -1.0),
270+
(5.0, 1.0, None),
271+
],
272+
)
257273
def test_random_moments(self, r, alpha, time_covariates_sum):
258-
dist = self.pymc_dist.dist(r=r, alpha=alpha, time_covariates_sum=time_covariates_sum, size=10_000)
274+
dist = self.pymc_dist.dist(
275+
r=r, alpha=alpha, time_covariates_sum=time_covariates_sum, size=10_000
276+
)
259277
draws = dist.eval()
260278

261279
# Check that all values are positive integers
@@ -288,10 +306,14 @@ def test_logp_basic(self):
288306
assert np.all(np.isfinite(logp_vals))
289307

290308
# Test invalid values
291-
assert logp_fn(np.array([0]), test_r, test_alpha, test_time_covariates_sum) == np.inf # Value must be > 0
309+
assert (
310+
logp_fn(np.array([0]), test_r, test_alpha, test_time_covariates_sum) == np.inf
311+
) # Value must be > 0
292312

293313
with pytest.raises(TypeError):
294-
logp_fn(np.array([1.5]), test_r, test_alpha, test_time_covariates_sum) # Value must be integer
314+
logp_fn(
315+
np.array([1.5]), test_r, test_alpha, test_time_covariates_sum
316+
) # Value must be integer
295317

296318
# Test parameter restrictions
297319
with pytest.raises(ParameterValueError):
@@ -305,23 +327,25 @@ def test_sampling_consistency(self):
305327
r = 2.0
306328
alpha = 1.0
307329
time_covariates_sum = None
308-
330+
309331
# First test direct sampling from the distribution
310332
dist = GrassiaIIGeometric.dist(r=r, alpha=alpha, time_covariates_sum=time_covariates_sum)
311333
direct_samples = dist.eval()
312-
334+
313335
# Convert to numpy array if it's not already
314336
if not isinstance(direct_samples, np.ndarray):
315337
direct_samples = np.array([direct_samples])
316-
338+
317339
# Ensure we have a 1D array
318340
if direct_samples.ndim == 0:
319341
direct_samples = direct_samples.reshape(1)
320-
342+
321343
assert direct_samples.size > 0, "Direct sampling produced no samples"
322344
assert np.all(direct_samples > 0), "Direct sampling produced non-positive values"
323-
assert np.all(direct_samples.astype(int) == direct_samples), "Direct sampling produced non-integer values"
324-
345+
assert np.all(
346+
direct_samples.astype(int) == direct_samples
347+
), "Direct sampling produced non-integer values"
348+
325349
# Then test MCMC sampling
326350
with pm.Model():
327351
x = GrassiaIIGeometric("x", r=r, alpha=alpha, time_covariates_sum=time_covariates_sum)
@@ -331,7 +355,7 @@ def test_sampling_consistency(self):
331355
samples = trace["x"].values
332356
assert samples is not None, "No samples were returned from MCMC"
333357
assert samples.size > 0, "MCMC sampling produced empty array"
334-
358+
335359
if samples.ndim > 1:
336360
samples = samples.reshape(-1) # Flatten if needed
337361

@@ -366,7 +390,9 @@ def test_sampling_consistency(self):
366390
def test_support_point(self, r, alpha, time_covariates_sum, size, expected_shape):
367391
"""Test that support_point returns reasonable values with correct shapes"""
368392
with pm.Model() as model:
369-
GrassiaIIGeometric("x", r=r, alpha=alpha, time_covariates_sum=time_covariates_sum, size=size)
393+
GrassiaIIGeometric(
394+
"x", r=r, alpha=alpha, time_covariates_sum=time_covariates_sum, size=size
395+
)
370396

371397
init_point = model.initial_point()["x"]
372398

0 commit comments

Comments
 (0)