From db91647d590e393f61aeefda6691bd662b33b1dd Mon Sep 17 00:00:00 2001
From: Asif Zubair
Date: Thu, 31 Jul 2025 14:53:13 -0500
Subject: [PATCH 1/3] feat-#6686-implementlogcdf-for-censoredRV

---
 pymc/distributions/censored.py       | 37 +++++++++++++
 tests/distributions/test_censored.py | 76 +++++++++++++++++++++++++++-
 2 files changed, 112 insertions(+), 1 deletion(-)

diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py
index 77c52023b..9fb85cacd 100644
--- a/pymc/distributions/censored.py
+++ b/pymc/distributions/censored.py
@@ -17,7 +17,9 @@
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import normalize_size_param
+from pytensor.tensor.variable import TensorConstant
 
+from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import (
     Distribution,
     SymbolicRandomVariable,
@@ -29,6 +31,7 @@
     implicit_size_from_params,
     rv_size_is_none,
 )
+from pymc.logprob.abstract import _logcdf
 from pymc.util import check_dist_not_registered
 
 
@@ -156,3 +159,37 @@ def support_point_censored(op, rv, dist, lower, upper):
     )
     support_point = pt.full_like(dist, support_point)
     return support_point
+
+
+@_logcdf.register(CensoredRV)
+def censored_logcdf(op, value, *inputs, **kwargs):
+    """Return the log-CDF of a censored variable evaluated at ``value``.
+
+    Inside the censoring interval this is the base variable's log-CDF; it is
+    ``-inf`` for values below ``lower`` and ``0`` for values at or above ``upper``.
+    """
+    base_rv, lower, upper = inputs
+
+    base_rv_op = base_rv.owner.op
+    base_rv_inputs = base_rv.owner.inputs
+    logcdf_val = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)
+
+    # A side is skipped only for a constant -inf lower or a constant +inf upper bound.
+    # np.isinf would also match a constant -inf upper bound and drop its clamp to 0.
+    is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
+    is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isposinf(upper.value)))
+
+    if is_lower_bounded:
+        logcdf_val = pt.switch(pt.lt(value, lower), -np.inf, logcdf_val)
+
+    if is_upper_bounded:
+        logcdf_val = pt.switch(pt.ge(value, upper), 0.0, logcdf_val)
+
+    if is_lower_bounded and is_upper_bounded:
+        logcdf_val = check_parameters(
+            logcdf_val,
+            pt.le(lower, upper),
+            msg="lower_bound <= upper_bound",
+        )
+
+    return logcdf_val
diff --git a/tests/distributions/test_censored.py b/tests/distributions/test_censored.py
index 21dce537b..f3aa7b3e1 100644
--- a/tests/distributions/test_censored.py
+++ b/tests/distributions/test_censored.py
@@ -17,7 +17,7 @@
 
 import pymc as pm
 
-from pymc import logp
+from pymc import logcdf, logp
 from pymc.distributions.shape_utils import change_dist_size
 
 
@@ -126,3 +126,77 @@ def test_censored_categorical(self):
             logp(censored_cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(),
             [0, 0, 0.3, 0.2, 0.5, 0, 0],
         )
+
+    def test_censored_logcdf_continuous(self):
+        norm = pm.Normal.dist(0, 1)
+        eval_points = np.array([-np.inf, -2, -1, 0, 1, 2, np.inf])
+
+        # No censoring
+        censored_norm = pm.Censored.dist(norm, lower=None, upper=None)
+        with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
+            censored_eval = logcdf(censored_norm, eval_points).eval()
+        with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
+            norm_eval = logcdf(norm, eval_points).eval()
+        np.testing.assert_allclose(censored_eval, norm_eval)
+
+        # Left censoring
+        censored_norm = pm.Censored.dist(norm, lower=-1, upper=None)
+        with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
+            censored_eval = logcdf(censored_norm, eval_points).eval()
+        np.testing.assert_allclose(
+            censored_eval,
+            np.array([-np.inf, -np.inf, -1.84102167, -0.69314718, -0.17275377, -0.02301291, 0.0]),
+            rtol=1e-6,
+        )
+
+        # Right censoring
+        censored_norm = pm.Censored.dist(norm, lower=None, upper=1)
+        with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
+            censored_eval = logcdf(censored_norm, eval_points).eval()
+        np.testing.assert_allclose(
+            censored_eval,
+            np.array([-np.inf, -3.78318435, -1.84102167, -0.69314718, 0, 0, 0.0]),
+            rtol=1e-6,
+        )
+
+        # Interval censoring
+        censored_norm = pm.Censored.dist(norm, lower=-1, upper=1)
+        with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
+            censored_eval = logcdf(censored_norm, eval_points).eval()
+        np.testing.assert_allclose(
+            censored_eval,
+            np.array([-np.inf, -np.inf, -1.84102167, -0.69314718, 0, 0, 0.0]),
+            rtol=1e-6,
+        )
+
+    def test_censored_logcdf_discrete(self):
+        cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2])
+        eval_points = np.array([-1, 0, 1, 2, 3, 4, 5])
+
+        # No censoring
+        censored_cat = pm.Censored.dist(cat, lower=None, upper=None)
+        np.testing.assert_allclose(
+            logcdf(censored_cat, eval_points).eval(),
+            logcdf(cat, eval_points).eval(),
+        )
+
+        # Left censoring
+        censored_cat = pm.Censored.dist(cat, lower=1, upper=None)
+        np.testing.assert_allclose(
+            logcdf(censored_cat, eval_points).eval(),
+            np.array([-np.inf, -np.inf, -1.2039728, -0.69314718, -0.22314355, 0, 0]),
+        )
+
+        # Right censoring
+        censored_cat = pm.Censored.dist(cat, lower=None, upper=3)
+        np.testing.assert_allclose(
+            logcdf(censored_cat, eval_points).eval(),
+            np.array([-np.inf, -2.30258509, -1.2039728, -0.69314718, 0, 0, 0]),
+        )
+
+        # Interval censoring
+        censored_cat = pm.Censored.dist(cat, lower=1, upper=3)
+        np.testing.assert_allclose(
+            logcdf(censored_cat, eval_points).eval(),
+            np.array([-np.inf, -np.inf, -1.2039728, -0.69314718, 0, 0, 0]),
+        )

From 024c7e9685081da9104df6414970ae99b0ec50cb Mon Sep 17 00:00:00 2001
From: Asif Zubair
Date: Thu, 31 Jul 2025 15:37:24 -0500
Subject: [PATCH 2/3] feat #6686: expect invalid value warn

---
 tests/distributions/test_censored.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/tests/distributions/test_censored.py b/tests/distributions/test_censored.py
index f3aa7b3e1..ed1c8d50f 100644
--- a/tests/distributions/test_censored.py
+++ b/tests/distributions/test_censored.py
@@ -131,17 +131,19 @@ def test_censored_logcdf_continuous(self):
         norm = pm.Normal.dist(0, 1)
         eval_points = np.array([-np.inf, -2, -1, 0, 1, 2, np.inf])
 
+        match_str = "divide by zero encountered in log|invalid value encountered in subtract"
+
         # No censoring
         censored_norm = pm.Censored.dist(norm, lower=None, upper=None)
-        with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
+        with pytest.warns(RuntimeWarning, match=match_str):
             censored_eval = logcdf(censored_norm, eval_points).eval()
-        with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
+        with pytest.warns(RuntimeWarning, match=match_str):
             norm_eval = logcdf(norm, eval_points).eval()
         np.testing.assert_allclose(censored_eval, norm_eval)
 
         # Left censoring
         censored_norm = pm.Censored.dist(norm, lower=-1, upper=None)
-        with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
+        with pytest.warns(RuntimeWarning, match=match_str):
             censored_eval = logcdf(censored_norm, eval_points).eval()
         np.testing.assert_allclose(
             censored_eval,
@@ -151,7 +153,7 @@ def test_censored_logcdf_continuous(self):
 
         # Right censoring
         censored_norm = pm.Censored.dist(norm, lower=None, upper=1)
-        with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
+        with pytest.warns(RuntimeWarning, match=match_str):
             censored_eval = logcdf(censored_norm, eval_points).eval()
         np.testing.assert_allclose(
             censored_eval,
@@ -161,7 +163,7 @@ def test_censored_logcdf_continuous(self):
 
         # Interval censoring
         censored_norm = pm.Censored.dist(norm, lower=-1, upper=1)
-        with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
+        with pytest.warns(RuntimeWarning, match=match_str):
             censored_eval = logcdf(censored_norm, eval_points).eval()
         np.testing.assert_allclose(
             censored_eval,

From a125914743d233606ec013d0e6f39d0443b9d4f0 Mon Sep 17 00:00:00 2001
From: Asif Zubair
Date: Thu, 31 Jul 2025 21:04:35 -0500
Subject: [PATCH 3/3] feat #6686: make tests more readable

---
 tests/distributions/test_censored.py | 37 +++++++++++++++++++---------
 1 file changed, 26 insertions(+), 11 deletions(-)

diff --git a/tests/distributions/test_censored.py b/tests/distributions/test_censored.py
index ed1c8d50f..e5a446665 100644
--- a/tests/distributions/test_censored.py
+++ b/tests/distributions/test_censored.py
@@ -14,6 +14,7 @@
 
 import numpy as np
 import pytest
+import scipy as sp
 
 import pymc as pm
 
@@ -130,6 +131,7 @@ def test_censored_categorical(self):
     def test_censored_logcdf_continuous(self):
         norm = pm.Normal.dist(0, 1)
         eval_points = np.array([-np.inf, -2, -1, 0, 1, 2, np.inf])
+        expected_logcdf_uncensored = sp.stats.norm.logcdf(eval_points)
 
         match_str = "divide by zero encountered in log|invalid value encountered in subtract"
 
@@ -137,68 +139,81 @@ def test_censored_logcdf_continuous(self):
         censored_norm = pm.Censored.dist(norm, lower=None, upper=None)
         with pytest.warns(RuntimeWarning, match=match_str):
             censored_eval = logcdf(censored_norm, eval_points).eval()
-        with pytest.warns(RuntimeWarning, match=match_str):
-            norm_eval = logcdf(norm, eval_points).eval()
-        np.testing.assert_allclose(censored_eval, norm_eval)
+        np.testing.assert_allclose(censored_eval, expected_logcdf_uncensored)
 
         # Left censoring
         censored_norm = pm.Censored.dist(norm, lower=-1, upper=None)
+        expected_left = np.where(eval_points < -1, -np.inf, expected_logcdf_uncensored)
         with pytest.warns(RuntimeWarning, match=match_str):
             censored_eval = logcdf(censored_norm, eval_points).eval()
         np.testing.assert_allclose(
             censored_eval,
-            np.array([-np.inf, -np.inf, -1.84102167, -0.69314718, -0.17275377, -0.02301291, 0.0]),
+            expected_left,
             rtol=1e-6,
         )
 
         # Right censoring
         censored_norm = pm.Censored.dist(norm, lower=None, upper=1)
+        expected_right = np.where(eval_points >= 1, 0.0, expected_logcdf_uncensored)
         with pytest.warns(RuntimeWarning, match=match_str):
             censored_eval = logcdf(censored_norm, eval_points).eval()
         np.testing.assert_allclose(
             censored_eval,
-            np.array([-np.inf, -3.78318435, -1.84102167, -0.69314718, 0, 0, 0.0]),
+            expected_right,
             rtol=1e-6,
         )
 
         # Interval censoring
         censored_norm = pm.Censored.dist(norm, lower=-1, upper=1)
+        expected_interval = np.where(eval_points < -1, -np.inf, expected_logcdf_uncensored)
+        expected_interval = np.where(eval_points >= 1, 0.0, expected_interval)
         with pytest.warns(RuntimeWarning, match=match_str):
             censored_eval = logcdf(censored_norm, eval_points).eval()
         np.testing.assert_allclose(
             censored_eval,
-            np.array([-np.inf, -np.inf, -1.84102167, -0.69314718, 0, 0, 0.0]),
+            expected_interval,
             rtol=1e-6,
         )
 
     def test_censored_logcdf_discrete(self):
-        cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2])
+        probs = [0.1, 0.2, 0.2, 0.3, 0.2]
+        cat = pm.Categorical.dist(probs)
         eval_points = np.array([-1, 0, 1, 2, 3, 4, 5])
 
+        cdf = np.cumsum(probs)
+        log_cdf_base = np.log(cdf)
+        expected_logcdf_uncensored = np.full_like(eval_points, -np.inf, dtype=float)
+        expected_logcdf_uncensored[1:6] = log_cdf_base
+        expected_logcdf_uncensored[6] = 0.0
+
         # No censoring
         censored_cat = pm.Censored.dist(cat, lower=None, upper=None)
         np.testing.assert_allclose(
             logcdf(censored_cat, eval_points).eval(),
-            logcdf(cat, eval_points).eval(),
+            expected_logcdf_uncensored,
         )
 
         # Left censoring
         censored_cat = pm.Censored.dist(cat, lower=1, upper=None)
+        expected_left = np.where(eval_points < 1, -np.inf, expected_logcdf_uncensored)
         np.testing.assert_allclose(
             logcdf(censored_cat, eval_points).eval(),
-            np.array([-np.inf, -np.inf, -1.2039728, -0.69314718, -0.22314355, 0, 0]),
+            expected_left,
         )
 
         # Right censoring
         censored_cat = pm.Censored.dist(cat, lower=None, upper=3)
+        expected_right = np.where(eval_points >= 3, 0.0, expected_logcdf_uncensored)
         np.testing.assert_allclose(
             logcdf(censored_cat, eval_points).eval(),
-            np.array([-np.inf, -2.30258509, -1.2039728, -0.69314718, 0, 0, 0]),
+            expected_right,
         )
 
         # Interval censoring
         censored_cat = pm.Censored.dist(cat, lower=1, upper=3)
+        expected_interval = np.where(eval_points < 1, -np.inf, expected_logcdf_uncensored)
+        expected_interval = np.where(eval_points >= 3, 0.0, expected_interval)
         np.testing.assert_allclose(
             logcdf(censored_cat, eval_points).eval(),
-            np.array([-np.inf, -np.inf, -1.2039728, -0.69314718, 0, 0, 0]),
+            expected_interval,
         )