@@ -214,8 +214,8 @@ def test_logp(self):
214
214
class TestGrassiaIIGeometric :
215
215
class TestRandomVariable (BaseTestDistributionRandom ):
216
216
pymc_dist = GrassiaIIGeometric
217
- pymc_dist_params = {"r" : .5 , "alpha" : 2.0 , "time_covariates_sum" : 1.0 }
218
- expected_rv_op_params = {"r" : .5 , "alpha" : 2.0 , "time_covariates_sum" : 1.0 }
217
+ pymc_dist_params = {"r" : 0 .5 , "alpha" : 2.0 , "time_covariates_sum" : 1.0 }
218
+ expected_rv_op_params = {"r" : 0 .5 , "alpha" : 2.0 , "time_covariates_sum" : 1.0 }
219
219
tests_to_run = [
220
220
"check_pymc_params_match_rv_op" ,
221
221
"check_rv_size" ,
@@ -228,10 +228,16 @@ def test_random_basic_properties(self):
228
228
paramdomains = {
229
229
"r" : Domain ([0.5 , 1.0 , 2.0 ], edges = (None , None )), # Standard values
230
230
"alpha" : Domain ([0.5 , 1.0 , 2.0 ], edges = (None , None )), # Standard values
231
- "time_covariates_sum" : Domain ([- 1.0 , 1.0 , 2.0 ], edges = (None , None )), # Time covariates
231
+ "time_covariates_sum" : Domain (
232
+ [- 1.0 , 1.0 , 2.0 ], edges = (None , None )
233
+ ), # Time covariates
232
234
},
233
235
ref_rand = lambda r , alpha , time_covariates_sum , size : np .random .geometric (
234
- 1 - np .exp (- np .random .gamma (r , 1 / alpha , size = size ) * np .exp (time_covariates_sum )), size = size
236
+ 1
237
+ - np .exp (
238
+ - np .random .gamma (r , 1 / alpha , size = size ) * np .exp (time_covariates_sum )
239
+ ),
240
+ size = size ,
235
241
),
236
242
)
237
243
@@ -241,21 +247,33 @@ def test_random_basic_properties(self):
241
247
paramdomains = {
242
248
"r" : Domain ([0.01 , 0.1 ], edges = (None , None )), # Small r values
243
249
"alpha" : Domain ([10.0 , 100.0 ], edges = (None , None )), # Large alpha values
244
- "time_covariates_sum" : Domain ([0.0 , 1.0 ], edges = (None , None )), # Time covariates
250
+ "time_covariates_sum" : Domain (
251
+ [0.0 , 1.0 ], edges = (None , None )
252
+ ), # Time covariates
245
253
},
246
254
ref_rand = lambda r , alpha , time_covariates_sum , size : np .random .geometric (
247
- np .clip (np .random .gamma (r , 1 / alpha , size = size ) * np .exp (time_covariates_sum ), 1e-5 , 1.0 ), size = size
255
+ np .clip (
256
+ np .random .gamma (r , 1 / alpha , size = size ) * np .exp (time_covariates_sum ),
257
+ 1e-5 ,
258
+ 1.0 ,
259
+ ),
260
+ size = size ,
248
261
),
249
262
)
250
263
251
- @pytest .mark .parametrize ("r,alpha,time_covariates_sum" , [
252
- (0.5 , 1.0 , 0.0 ),
253
- (1.0 , 2.0 , 1.0 ),
254
- (2.0 , 0.5 , - 1.0 ),
255
- (5.0 , 1.0 , None ),
256
- ])
264
+ @pytest .mark .parametrize (
265
+ "r,alpha,time_covariates_sum" ,
266
+ [
267
+ (0.5 , 1.0 , 0.0 ),
268
+ (1.0 , 2.0 , 1.0 ),
269
+ (2.0 , 0.5 , - 1.0 ),
270
+ (5.0 , 1.0 , None ),
271
+ ],
272
+ )
257
273
def test_random_moments (self , r , alpha , time_covariates_sum ):
258
- dist = self .pymc_dist .dist (r = r , alpha = alpha , time_covariates_sum = time_covariates_sum , size = 10_000 )
274
+ dist = self .pymc_dist .dist (
275
+ r = r , alpha = alpha , time_covariates_sum = time_covariates_sum , size = 10_000
276
+ )
259
277
draws = dist .eval ()
260
278
261
279
# Check that all values are positive integers
@@ -288,10 +306,14 @@ def test_logp_basic(self):
288
306
assert np .all (np .isfinite (logp_vals ))
289
307
290
308
# Test invalid values
291
- assert logp_fn (np .array ([0 ]), test_r , test_alpha , test_time_covariates_sum ) == np .inf # Value must be > 0
309
+ assert (
310
+ logp_fn (np .array ([0 ]), test_r , test_alpha , test_time_covariates_sum ) == np .inf
311
+ ) # Value must be > 0
292
312
293
313
with pytest .raises (TypeError ):
294
- logp_fn (np .array ([1.5 ]), test_r , test_alpha , test_time_covariates_sum ) # Value must be integer
314
+ logp_fn (
315
+ np .array ([1.5 ]), test_r , test_alpha , test_time_covariates_sum
316
+ ) # Value must be integer
295
317
296
318
# Test parameter restrictions
297
319
with pytest .raises (ParameterValueError ):
@@ -305,23 +327,25 @@ def test_sampling_consistency(self):
305
327
r = 2.0
306
328
alpha = 1.0
307
329
time_covariates_sum = None
308
-
330
+
309
331
# First test direct sampling from the distribution
310
332
dist = GrassiaIIGeometric .dist (r = r , alpha = alpha , time_covariates_sum = time_covariates_sum )
311
333
direct_samples = dist .eval ()
312
-
334
+
313
335
# Convert to numpy array if it's not already
314
336
if not isinstance (direct_samples , np .ndarray ):
315
337
direct_samples = np .array ([direct_samples ])
316
-
338
+
317
339
# Ensure we have a 1D array
318
340
if direct_samples .ndim == 0 :
319
341
direct_samples = direct_samples .reshape (1 )
320
-
342
+
321
343
assert direct_samples .size > 0 , "Direct sampling produced no samples"
322
344
assert np .all (direct_samples > 0 ), "Direct sampling produced non-positive values"
323
- assert np .all (direct_samples .astype (int ) == direct_samples ), "Direct sampling produced non-integer values"
324
-
345
+ assert np .all (
346
+ direct_samples .astype (int ) == direct_samples
347
+ ), "Direct sampling produced non-integer values"
348
+
325
349
# Then test MCMC sampling
326
350
with pm .Model ():
327
351
x = GrassiaIIGeometric ("x" , r = r , alpha = alpha , time_covariates_sum = time_covariates_sum )
@@ -331,7 +355,7 @@ def test_sampling_consistency(self):
331
355
samples = trace ["x" ].values
332
356
assert samples is not None , "No samples were returned from MCMC"
333
357
assert samples .size > 0 , "MCMC sampling produced empty array"
334
-
358
+
335
359
if samples .ndim > 1 :
336
360
samples = samples .reshape (- 1 ) # Flatten if needed
337
361
@@ -366,7 +390,9 @@ def test_sampling_consistency(self):
366
390
def test_support_point (self , r , alpha , time_covariates_sum , size , expected_shape ):
367
391
"""Test that support_point returns reasonable values with correct shapes"""
368
392
with pm .Model () as model :
369
- GrassiaIIGeometric ("x" , r = r , alpha = alpha , time_covariates_sum = time_covariates_sum , size = size )
393
+ GrassiaIIGeometric (
394
+ "x" , r = r , alpha = alpha , time_covariates_sum = time_covariates_sum , size = size
395
+ )
370
396
371
397
init_point = model .initial_point ()["x" ]
372
398
0 commit comments