diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py index 77c52023b..9fb85cacd 100644 --- a/pymc/distributions/censored.py +++ b/pymc/distributions/censored.py @@ -17,7 +17,9 @@ from pytensor.tensor import TensorVariable from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.utils import normalize_size_param +from pytensor.tensor.variable import TensorConstant +from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import ( Distribution, SymbolicRandomVariable, @@ -29,6 +31,7 @@ implicit_size_from_params, rv_size_is_none, ) +from pymc.logprob.abstract import _logcdf from pymc.util import check_dist_not_registered @@ -156,3 +159,30 @@ def support_point_censored(op, rv, dist, lower, upper): ) support_point = pt.full_like(dist, support_point) return support_point + + +@_logcdf.register(CensoredRV) +def censored_logcdf(op, value, *inputs, **kwargs): + base_rv, lower, upper = inputs + + base_rv_op = base_rv.owner.op + base_rv_inputs = base_rv.owner.inputs + logcdf_val = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs) + + is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value))) + is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value))) + + if is_lower_bounded: + logcdf_val = pt.switch(pt.lt(value, lower), -np.inf, logcdf_val) + + if is_upper_bounded: + logcdf_val = pt.switch(pt.ge(value, upper), 0.0, logcdf_val) + + if is_lower_bounded and is_upper_bounded: + logcdf_val = check_parameters( + logcdf_val, + pt.le(lower, upper), + msg="lower_bound <= upper_bound", + ) + + return logcdf_val diff --git a/tests/distributions/test_censored.py b/tests/distributions/test_censored.py index 21dce537b..e5a446665 100644 --- a/tests/distributions/test_censored.py +++ b/tests/distributions/test_censored.py @@ -14,10 +14,11 @@ import numpy as np import pytest +import scipy as sp import pymc as pm -from pymc import logp +from pymc import logcdf, logp from pymc.distributions.shape_utils import change_dist_size @@ -126,3 +127,93 @@ def test_censored_categorical(self): logp(censored_cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(), [0, 0, 0.3, 0.2, 0.5, 0, 0], ) + + def test_censored_logcdf_continuous(self): + norm = pm.Normal.dist(0, 1) + eval_points = np.array([-np.inf, -2, -1, 0, 1, 2, np.inf]) + expected_logcdf_uncensored = sp.stats.norm.logcdf(eval_points) + + match_str = "divide by zero encountered in log|invalid value encountered in subtract" + + # No censoring + censored_norm = pm.Censored.dist(norm, lower=None, upper=None) + with pytest.warns(RuntimeWarning, match=match_str): + censored_eval = logcdf(censored_norm, eval_points).eval() + np.testing.assert_allclose(censored_eval, expected_logcdf_uncensored) + + # Left censoring + censored_norm = pm.Censored.dist(norm, lower=-1, upper=None) + expected_left = np.where(eval_points < -1, -np.inf, expected_logcdf_uncensored) + with pytest.warns(RuntimeWarning, match=match_str): + censored_eval = logcdf(censored_norm, eval_points).eval() + np.testing.assert_allclose( + censored_eval, + expected_left, + rtol=1e-6, + ) + + # Right censoring + censored_norm = pm.Censored.dist(norm, lower=None, upper=1) + expected_right = np.where(eval_points >= 1, 0.0, expected_logcdf_uncensored) + with pytest.warns(RuntimeWarning, match=match_str): + censored_eval = logcdf(censored_norm, eval_points).eval() + np.testing.assert_allclose( + censored_eval, + expected_right, + rtol=1e-6, + ) + + # Interval censoring + censored_norm = pm.Censored.dist(norm, lower=-1, upper=1) + expected_interval = np.where(eval_points < -1, -np.inf, expected_logcdf_uncensored) + expected_interval = np.where(eval_points >= 1, 0.0, expected_interval) + with pytest.warns(RuntimeWarning, match=match_str): + censored_eval = logcdf(censored_norm, eval_points).eval() + np.testing.assert_allclose( + censored_eval, + expected_interval, + rtol=1e-6, + ) + + def test_censored_logcdf_discrete(self): + probs = [0.1, 0.2, 0.2, 0.3, 0.2] + cat = pm.Categorical.dist(probs) + eval_points = np.array([-1, 0, 1, 2, 3, 4, 5]) + + cdf = np.cumsum(probs) + log_cdf_base = np.log(cdf) + expected_logcdf_uncensored = np.full_like(eval_points, -np.inf, dtype=float) + expected_logcdf_uncensored[1:6] = log_cdf_base + expected_logcdf_uncensored[6] = 0.0 + + # No censoring + censored_cat = pm.Censored.dist(cat, lower=None, upper=None) + np.testing.assert_allclose( + logcdf(censored_cat, eval_points).eval(), + expected_logcdf_uncensored, + ) + + # Left censoring + censored_cat = pm.Censored.dist(cat, lower=1, upper=None) + expected_left = np.where(eval_points < 1, -np.inf, expected_logcdf_uncensored) + np.testing.assert_allclose( + logcdf(censored_cat, eval_points).eval(), + expected_left, + ) + + # Right censoring + censored_cat = pm.Censored.dist(cat, lower=None, upper=3) + expected_right = np.where(eval_points >= 3, 0.0, expected_logcdf_uncensored) + np.testing.assert_allclose( + logcdf(censored_cat, eval_points).eval(), + expected_right, + ) + + # Interval censoring + censored_cat = pm.Censored.dist(cat, lower=1, upper=3) + expected_interval = np.where(eval_points < 1, -np.inf, expected_logcdf_uncensored) + expected_interval = np.where(eval_points >= 3, 0.0, expected_interval) + np.testing.assert_allclose( + logcdf(censored_cat, eval_points).eval(), + expected_interval, + )