|
17 | 17 | from pytensor.graph.replace import clone_replace
|
18 | 18 | from pytensor.graph.rewriting.db import RewriteDatabaseQuery
|
19 | 19 | from pytensor.tensor.random.basic import (
|
| 20 | + _gamma, |
20 | 21 | bernoulli,
|
21 | 22 | beta,
|
22 | 23 | betabinom,
|
@@ -351,20 +352,31 @@ def test_lognormal_samples(mean, sigma, size):
|
351 | 352 | ],
|
352 | 353 | )
|
353 | 354 | def test_gamma_samples(a, b, size):
|
354 |
| - gamma_test_fn = fixed_scipy_rvs("gamma") |
355 |
| - |
356 |
| - def test_fn(shape, rate, **kwargs): |
357 |
| - return gamma_test_fn(shape, scale=1.0 / rate, **kwargs) |
358 |
| - |
359 | 355 | compare_sample_values(
|
360 |
| - gamma, |
| 356 | + _gamma, |
361 | 357 | a,
|
362 | 358 | b,
|
363 | 359 | size=size,
|
364 |
| - test_fn=test_fn, |
365 | 360 | )
|
366 | 361 |
|
367 | 362 |
|
| 363 | +def test_gamma_deprecation_wrapper_fn(): |
| 364 | + out = gamma(5.0, scale=0.5, size=(5,)) |
| 365 | + assert out.type.shape == (5,) |
| 366 | + assert out.owner.inputs[-1].eval() == 0.5 |
| 367 | + |
| 368 | + with pytest.warns(FutureWarning, match="Gamma rate argument is deprecated"): |
| 369 | + out = gamma([5.0, 10.0], 2.0, size=None) |
| 370 | + assert out.type.shape == (2,) |
| 371 | + assert out.owner.inputs[-1].eval() == 0.5 |
| 372 | + |
| 373 | + with pytest.raises(ValueError, match="Must specify scale"): |
| 374 | + gamma(5.0) |
| 375 | + |
| 376 | + with pytest.raises(ValueError, match="Cannot specify both rate and scale"): |
| 377 | + gamma(5.0, rate=2.0, scale=0.5) |
| 378 | + |
| 379 | + |
368 | 380 | @pytest.mark.parametrize(
|
369 | 381 | "df, size",
|
370 | 382 | [
|
@@ -470,18 +482,24 @@ def test_vonmises_samples(mu, kappa, size):
|
470 | 482 |
|
471 | 483 |
|
472 | 484 | @pytest.mark.parametrize(
|
473 |
| - "alpha, size", |
| 485 | + "alpha, scale, size", |
474 | 486 | [
|
475 |
| - (np.array(0.5, dtype=config.floatX), None), |
476 |
| - (np.array(0.5, dtype=config.floatX), []), |
| 487 | + (np.array(0.5, dtype=config.floatX), np.array(3.0, dtype=config.floatX), None), |
| 488 | + (np.array(0.5, dtype=config.floatX), np.array(5.0, dtype=config.floatX), []), |
477 | 489 | (
|
478 | 490 | np.full((1, 2), 0.5, dtype=config.floatX),
|
| 491 | + np.array([0.5, 1.0], dtype=config.floatX), |
479 | 492 | None,
|
480 | 493 | ),
|
481 | 494 | ],
|
482 | 495 | )
|
483 |
| -def test_pareto_samples(alpha, size): |
484 |
| - compare_sample_values(pareto, alpha, size=size, test_fn=fixed_scipy_rvs("pareto")) |
| 496 | +def test_pareto_samples(alpha, scale, size): |
| 497 | + pareto_test_fn = fixed_scipy_rvs("pareto") |
| 498 | + |
| 499 | + def test_fn(shape, scale, **kwargs): |
| 500 | + return pareto_test_fn(shape, scale=scale, **kwargs) |
| 501 | + |
| 502 | + compare_sample_values(pareto, alpha, scale, size=size, test_fn=test_fn) |
485 | 503 |
|
486 | 504 |
|
487 | 505 | def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
|
|
0 commit comments