Skip to content

Commit 86bac08

Browse files
committed
added logit normal icdf and test
1 parent 340e403 commit 86bac08

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

pymc/distributions/continuous.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3700,6 +3700,12 @@ def logp(value, mu, sigma):
37003700
msg="tau > 0",
37013701
)
37023702

3703+
def icdf(value, mu, sigma):
3704+
# F^{-1}_{LogitNormal}(q) = sigmoid( mu + sigma * Phi^{-1}(q) )
3705+
# where Phi^{-1} is the Normal icdf
3706+
res = invlogit(icdf(Normal.dist(mu, sigma), value))
3707+
res = check_icdf_value(res, value)
3708+
return check_icdf_parameters(res, sigma > 0, msg="sigma > 0")
37033709

37043710
def _interpolated_argcdf(p, pdf, cdf, x):
37053711
if np.prod(cdf.shape[:-1]) != 1 or np.prod(pdf.shape[:-1]) != 1 or np.prod(x.shape[:-1]) != 1:

tests/distributions/test_continuous.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,12 @@ def test_logitnormal(self):
872872
),
873873
decimal=select_by_precision(float64=6, float32=1),
874874
)
875+
check_icdf(
876+
pm.LogitNormal,
877+
{"mu": R, "sigma": Rplus},
878+
lambda q, mu, sigma: sp.expit(mu + sigma * st.norm.ppf(q)),
879+
decimal=select_by_precision(float64=12, float32=5),
880+
)
875881

876882
@pytest.mark.skipif(
877883
condition=(pytensor.config.floatX == "float32"),

0 commit comments

Comments
 (0)