Skip to content

Commit 127ac30

Browse files
authored
Implement logcdf for CensoredRV (#7884)
1 parent 011fb35 commit 127ac30

File tree

2 files changed

+122
-1
lines changed

2 files changed

+122
-1
lines changed

pymc/distributions/censored.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from pytensor.tensor import TensorVariable
1818
from pytensor.tensor.random.op import RandomVariable
1919
from pytensor.tensor.random.utils import normalize_size_param
20+
from pytensor.tensor.variable import TensorConstant
2021

22+
from pymc.distributions.dist_math import check_parameters
2123
from pymc.distributions.distribution import (
2224
Distribution,
2325
SymbolicRandomVariable,
@@ -29,6 +31,7 @@
2931
implicit_size_from_params,
3032
rv_size_is_none,
3133
)
34+
from pymc.logprob.abstract import _logcdf
3235
from pymc.util import check_dist_not_registered
3336

3437

@@ -156,3 +159,30 @@ def support_point_censored(op, rv, dist, lower, upper):
156159
)
157160
support_point = pt.full_like(dist, support_point)
158161
return support_point
162+
163+
164+
@_logcdf.register(CensoredRV)
def censored_logcdf(op, value, *inputs, **kwargs):
    """Return the log-CDF of a censored variable evaluated at ``value``.

    The result starts from the log-CDF of the uncensored base variable and is
    overridden where censoring applies: ``-inf`` for values strictly below
    ``lower`` and ``0`` for values at or above ``upper``.

    Parameters
    ----------
    op
        The ``CensoredRV`` op this implementation is dispatched on. It is not
        used in the body.
    value
        Point(s) at which the log-CDF is evaluated.
    *inputs
        Exactly three entries, unpacked as ``(base_rv, lower, upper)``.
    **kwargs
        Forwarded unchanged to the base variable's ``_logcdf`` implementation.

    Returns
    -------
    The log-CDF expression. When both bounds are considered active it is
    wrapped by ``check_parameters`` with the condition ``lower <= upper``.
    """
    base_rv, lower, upper = inputs

    # Delegate to the logcdf registered for the op that produced the base variable.
    base_rv_op = base_rv.owner.op
    base_rv_inputs = base_rv.owner.inputs
    logcdf_val = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)

    # A bound is only treated as absent when it is a constant that is entirely
    # infinite in the non-binding direction; any non-constant bound is
    # conservatively treated as active.
    is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
    # Use the one-sided check: ``np.isinf`` is also true for ``-inf``, which
    # would wrongly disable the override for an upper bound of ``-inf``.
    is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isposinf(upper.value)))

    if is_lower_bounded:
        # No probability mass lies strictly below the lower bound.
        logcdf_val = pt.switch(pt.lt(value, lower), -np.inf, logcdf_val)

    if is_upper_bounded:
        # All probability mass lies at or below the upper bound: log(1) == 0.
        logcdf_val = pt.switch(pt.ge(value, upper), 0.0, logcdf_val)

    if is_lower_bounded and is_upper_bounded:
        logcdf_val = check_parameters(
            logcdf_val,
            pt.le(lower, upper),
            msg="lower_bound <= upper_bound",
        )

    return logcdf_val

tests/distributions/test_censored.py

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414

1515
import numpy as np
1616
import pytest
17+
import scipy as sp
1718

1819
import pymc as pm
1920

20-
from pymc import logp
21+
from pymc import logcdf, logp
2122
from pymc.distributions.shape_utils import change_dist_size
2223

2324

@@ -126,3 +127,93 @@ def test_censored_categorical(self):
126127
logp(censored_cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(),
127128
[0, 0, 0.3, 0.2, 0.5, 0, 0],
128129
)
130+
131+
def test_censored_logcdf_continuous(self):
    """Compare the logcdf of a censored standard Normal with scipy's reference.

    Covers no censoring, left censoring at -1, right censoring at 1 and
    interval censoring on [-1, 1]. Each evaluation is expected to emit a
    RuntimeWarning matching the pattern below.
    """
    base = pm.Normal.dist(0, 1)
    points = np.array([-np.inf, -2, -1, 0, 1, 2, np.inf])
    reference = sp.stats.norm.logcdf(points)
    warning_pattern = "divide by zero encountered in log|invalid value encountered in subtract"

    def evaluate(lower, upper):
        # Build the censored distribution and evaluate its logcdf at all points.
        dist = pm.Censored.dist(base, lower=lower, upper=upper)
        with pytest.warns(RuntimeWarning, match=warning_pattern):
            return logcdf(dist, points).eval()

    below_lower = points < -1
    at_or_above_upper = points >= 1

    # No censoring: the result matches the plain Normal logcdf.
    np.testing.assert_allclose(evaluate(None, None), reference)

    # Left censoring: -inf strictly below the lower bound.
    left_expected = np.where(below_lower, -np.inf, reference)
    np.testing.assert_allclose(evaluate(-1, None), left_expected, rtol=1e-6)

    # Right censoring: 0 at or above the upper bound.
    right_expected = np.where(at_or_above_upper, 0.0, reference)
    np.testing.assert_allclose(evaluate(None, 1), right_expected, rtol=1e-6)

    # Interval censoring: both overrides, the upper one applied last.
    interval_expected = np.where(at_or_above_upper, 0.0, left_expected)
    np.testing.assert_allclose(evaluate(-1, 1), interval_expected, rtol=1e-6)
177+
178+
def test_censored_logcdf_discrete(self):
    """Compare the logcdf of a censored Categorical with a hand-built reference.

    Covers no censoring, left censoring at 1, right censoring at 3 and
    interval censoring on [1, 3].
    """
    probs = [0.1, 0.2, 0.2, 0.3, 0.2]
    cat = pm.Categorical.dist(probs)
    points = np.array([-1, 0, 1, 2, 3, 4, 5])

    # Reference for points -1..5: -inf below the support, the log of the
    # cumulative probabilities on 0..4, and 0 above the support.
    reference = np.concatenate(([-np.inf], np.log(np.cumsum(probs)), [0.0]))

    left_expected = np.where(points < 1, -np.inf, reference)
    right_expected = np.where(points >= 3, 0.0, reference)
    interval_expected = np.where(points >= 3, 0.0, left_expected)

    # (lower, upper, expected logcdf), in the order:
    # no censoring, left, right, interval.
    scenarios = [
        (None, None, reference),
        (1, None, left_expected),
        (None, 3, right_expected),
        (1, 3, interval_expected),
    ]
    for lower, upper, expected in scenarios:
        censored_cat = pm.Censored.dist(cat, lower=lower, upper=upper)
        np.testing.assert_allclose(
            logcdf(censored_cat, points).eval(),
            expected,
        )

0 commit comments

Comments
 (0)