diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py index 7de2b72e23..27d2817550 100644 --- a/pymc/distributions/discrete.py +++ b/pymc/distributions/discrete.py @@ -1191,6 +1191,33 @@ def logp(value, p): msg="0 <= p <=1, sum(p) = 1", ) + def icdf(value, p): + eps = 1e-12 + q = value + q_safe = pt.clip(q, 0.0, 1.0 - eps) + cdf = pt.cumsum(p, axis=-1) + + cdf_batch_ndim = cdf.ndim - 1 + q_ndim = q_safe.ndim + if q_ndim < cdf_batch_ndim: + q_safe = pt.shape_padleft(q_safe, cdf_batch_ndim - q_ndim) + elif q_ndim > cdf_batch_ndim: + extra = q_ndim - cdf_batch_ndim + axes = list(range(cdf.ndim - 1)) + ["x"] * extra + [cdf.ndim - 1] + cdf = cdf.dimshuffle(axes) + + mask = pt.shape_padright(q_safe, 1) <= cdf + idx = pt.argmax(mask, axis=-1).astype("int64") + + idx = check_icdf_value(idx, q) + return check_icdf_parameters( + idx, + 0 <= p, + p <= 1, + pt.isclose(pt.sum(p, axis=-1), 1), + msg="0 <= p <=1, sum(p) = 1", + ) + def logcdf(value, p): k = pt.shape(p)[-1] value, safe_value_p = Categorical._safe_index_value_p(value, p.cumsum(-1)) diff --git a/tests/distributions/test_discrete.py b/tests/distributions/test_discrete.py index 55e8c23128..4c09a16ef8 100644 --- a/tests/distributions/test_discrete.py +++ b/tests/distributions/test_discrete.py @@ -796,6 +796,33 @@ class TestCategorical(BaseTestDistributionRandom): "check_rv_size", ] + @pytest.mark.parametrize("n", [2, 3, 4]) + def test_categorical_icdf(self, n): + paramdomains = {"p": Simplex(n)} + + def numpy_categorical_ppf(q, p): + cdf = np.cumsum(p, axis=-1) + q = np.asarray(q) + return np.argmax(q[..., None] <= cdf, axis=-1) + + check_icdf(pm.Categorical, paramdomains, numpy_categorical_ppf) + + def test_categorical_icdf_batch_shapes(self): + p = np.array([[0.2, 0.3, 0.5], [0.1, 0.1, 0.8]]) + q_vec = np.array([0.0, 0.25]) + dist = pm.Categorical.dist(p=p) + out_vec = icdf(dist, q_vec).eval() + np.testing.assert_array_equal(out_vec, np.array([0, 2])) + q_mat = np.array([[0.05, 0.6, 0.99], [0.21, 0.19, 0.81]]) + out_mat = icdf(dist, q_mat).eval() + np.testing.assert_array_equal(out_mat, np.array([[0, 2, 2], [2, 1, 2]])) + + def test_categorical_icdf_upper_edge(self): + p = np.array([0.1, 0.2, 0.7]) + dist = pm.Categorical.dist(p=p) + out = icdf(dist, np.array([1.0])).eval() + assert out[0] == 2 + class TestLogitCategorical(BaseTestDistributionRandom): pymc_dist = pm.Categorical