Skip to content

Commit 8d70551

Browse files
committed
implement invgamma icdf + test
1 parent 129ec62 commit 8d70551

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

pymc/distributions/continuous.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from pytensor.tensor import gammaln, get_underlying_scalar_constant_value
3030
from pytensor.tensor.exceptions import NotScalarConstantError
3131
from pytensor.tensor.extra_ops import broadcast_shape
32-
from pytensor.tensor.math import betaincinv, gammaincinv, tanh
32+
from pytensor.tensor.math import betaincinv, gammaincinv, gammainccinv, tanh
3333
from pytensor.tensor.random.basic import (
3434
BetaRV,
3535
_gamma,
@@ -2541,6 +2541,16 @@ def logcdf(value, alpha, beta):
25412541
msg="alpha > 0, beta > 0",
25422542
)
25432543

2544+
def icdf(value, alpha, beta):
2545+
res = beta / gammainccinv(alpha, value)
2546+
res = check_icdf_value(res, value)
2547+
return check_icdf_parameters(
2548+
res,
2549+
alpha > 0,
2550+
beta > 0,
2551+
msg="alpha > 0, beta > 0",
2552+
)
2553+
25442554

25452555
class ChiSquared:
25462556
r"""

tests/distributions/test_continuous.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,13 @@ def test_inverse_gamma_logp(self):
675675
lambda value, alpha, beta: st.invgamma.logpdf(value, alpha, scale=beta),
676676
)
677677

678+
def test_inverse_gamma_icdf(self):
679+
check_icdf(
680+
pm.InverseGamma,
681+
{"alpha": Rplusbig, "beta": Rplusbig},
682+
lambda q, alpha, beta: st.invgamma.ppf(q, alpha, scale=beta),
683+
)
684+
678685
@pytest.mark.skipif(
679686
condition=(pytensor.config.floatX == "float32"),
680687
reason="Fails on float32 due to numerical issues",

0 commit comments

Comments
 (0)