|
29 | 29 | import pymc as pm
|
30 | 30 |
|
31 | 31 | from pymc.distributions.discrete import Geometric, _OrderedLogistic, _OrderedProbit
|
32 |
| -from pymc.logprob.abstract import logcdf |
| 32 | +from pymc.logprob.abstract import icdf, logcdf |
33 | 33 | from pymc.logprob.joint_logprob import logp
|
34 | 34 | from pymc.logprob.utils import ParameterValueError
|
35 | 35 | from pymc.pytensorf import floatX
|
@@ -118,13 +118,21 @@ def test_discrete_unif(self):
|
118 | 118 | Domain([-10, 0, 10], "int64"),
|
119 | 119 | {"lower": -Rplusdunif, "upper": Rplusdunif},
|
120 | 120 | )
|
| 121 | + check_icdf( |
| 122 | + pm.DiscreteUniform, |
| 123 | + {"lower": -Rplusdunif, "upper": Rplusdunif}, |
| 124 | + lambda q, lower, upper: st.randint.ppf(q=q, low=lower, high=upper + 1), |
| 125 | + skip_paramdomain_outside_edge_test=True, |
| 126 | + ) |
121 | 127 | # Custom logp / logcdf check for invalid parameters
|
122 | 128 | invalid_dist = pm.DiscreteUniform.dist(lower=1, upper=0)
|
123 | 129 | with pytensor.config.change_flags(mode=Mode("py")):
|
124 | 130 | with pytest.raises(ParameterValueError):
|
125 | 131 | logp(invalid_dist, 0.5).eval()
|
126 | 132 | with pytest.raises(ParameterValueError):
|
127 | 133 | logcdf(invalid_dist, 2).eval()
|
| 134 | + with pytest.raises(ParameterValueError): |
| 135 | + icdf(invalid_dist, np.array(1)).eval() |
128 | 136 |
|
129 | 137 | def test_geometric(self):
|
130 | 138 | check_logp(
|
|
0 commit comments