Skip to content

Commit db91647

Browse files
committed
feat-#6686-implementlogcdf-for-censoredRV
1 parent 011fb35 commit db91647

File tree

2 files changed

+105
-1
lines changed

2 files changed

+105
-1
lines changed

pymc/distributions/censored.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from pytensor.tensor import TensorVariable
1818
from pytensor.tensor.random.op import RandomVariable
1919
from pytensor.tensor.random.utils import normalize_size_param
20+
from pytensor.tensor.variable import TensorConstant
2021

22+
from pymc.distributions.dist_math import check_parameters
2123
from pymc.distributions.distribution import (
2224
Distribution,
2325
SymbolicRandomVariable,
@@ -29,6 +31,7 @@
2931
implicit_size_from_params,
3032
rv_size_is_none,
3133
)
34+
from pymc.logprob.abstract import _logcdf
3235
from pymc.util import check_dist_not_registered
3336

3437

@@ -156,3 +159,30 @@ def support_point_censored(op, rv, dist, lower, upper):
156159
)
157160
support_point = pt.full_like(dist, support_point)
158161
return support_point
162+
163+
164+
@_logcdf.register(CensoredRV)
165+
def censored_logcdf(op, value, *inputs, **kwargs):
166+
base_rv, lower, upper = inputs
167+
168+
base_rv_op = base_rv.owner.op
169+
base_rv_inputs = base_rv.owner.inputs
170+
logcdf_val = _logcdf(base_rv_op, value, *base_rv_inputs, **kwargs)
171+
172+
is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
173+
is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
174+
175+
if is_lower_bounded:
176+
logcdf_val = pt.switch(pt.lt(value, lower), -np.inf, logcdf_val)
177+
178+
if is_upper_bounded:
179+
logcdf_val = pt.switch(pt.ge(value, upper), 0.0, logcdf_val)
180+
181+
if is_lower_bounded and is_upper_bounded:
182+
logcdf_val = check_parameters(
183+
logcdf_val,
184+
pt.le(lower, upper),
185+
msg="lower_bound <= upper_bound",
186+
)
187+
188+
return logcdf_val

tests/distributions/test_censored.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import pymc as pm
1919

20-
from pymc import logp
20+
from pymc import logcdf, logp
2121
from pymc.distributions.shape_utils import change_dist_size
2222

2323

@@ -126,3 +126,77 @@ def test_censored_categorical(self):
126126
logp(censored_cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(),
127127
[0, 0, 0.3, 0.2, 0.5, 0, 0],
128128
)
129+
130+
def test_censored_logcdf_continuous(self):
131+
norm = pm.Normal.dist(0, 1)
132+
eval_points = np.array([-np.inf, -2, -1, 0, 1, 2, np.inf])
133+
134+
# No censoring
135+
censored_norm = pm.Censored.dist(norm, lower=None, upper=None)
136+
with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
137+
censored_eval = logcdf(censored_norm, eval_points).eval()
138+
with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
139+
norm_eval = logcdf(norm, eval_points).eval()
140+
np.testing.assert_allclose(censored_eval, norm_eval)
141+
142+
# Left censoring
143+
censored_norm = pm.Censored.dist(norm, lower=-1, upper=None)
144+
with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
145+
censored_eval = logcdf(censored_norm, eval_points).eval()
146+
np.testing.assert_allclose(
147+
censored_eval,
148+
np.array([-np.inf, -np.inf, -1.84102167, -0.69314718, -0.17275377, -0.02301291, 0.0]),
149+
rtol=1e-6,
150+
)
151+
152+
# Right censoring
153+
censored_norm = pm.Censored.dist(norm, lower=None, upper=1)
154+
with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
155+
censored_eval = logcdf(censored_norm, eval_points).eval()
156+
np.testing.assert_allclose(
157+
censored_eval,
158+
np.array([-np.inf, -3.78318435, -1.84102167, -0.69314718, 0, 0, 0.0]),
159+
rtol=1e-6,
160+
)
161+
162+
# Interval censoring
163+
censored_norm = pm.Censored.dist(norm, lower=-1, upper=1)
164+
with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"):
165+
censored_eval = logcdf(censored_norm, eval_points).eval()
166+
np.testing.assert_allclose(
167+
censored_eval,
168+
np.array([-np.inf, -np.inf, -1.84102167, -0.69314718, 0, 0, 0.0]),
169+
rtol=1e-6,
170+
)
171+
172+
def test_censored_logcdf_discrete(self):
173+
cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2])
174+
eval_points = np.array([-1, 0, 1, 2, 3, 4, 5])
175+
176+
# No censoring
177+
censored_cat = pm.Censored.dist(cat, lower=None, upper=None)
178+
np.testing.assert_allclose(
179+
logcdf(censored_cat, eval_points).eval(),
180+
logcdf(cat, eval_points).eval(),
181+
)
182+
183+
# Left censoring
184+
censored_cat = pm.Censored.dist(cat, lower=1, upper=None)
185+
np.testing.assert_allclose(
186+
logcdf(censored_cat, eval_points).eval(),
187+
np.array([-np.inf, -np.inf, -1.2039728, -0.69314718, -0.22314355, 0, 0]),
188+
)
189+
190+
# Right censoring
191+
censored_cat = pm.Censored.dist(cat, lower=None, upper=3)
192+
np.testing.assert_allclose(
193+
logcdf(censored_cat, eval_points).eval(),
194+
np.array([-np.inf, -2.30258509, -1.2039728, -0.69314718, 0, 0, 0]),
195+
)
196+
197+
# Interval censoring
198+
censored_cat = pm.Censored.dist(cat, lower=1, upper=3)
199+
np.testing.assert_allclose(
200+
logcdf(censored_cat, eval_points).eval(),
201+
np.array([-np.inf, -np.inf, -1.2039728, -0.69314718, 0, 0, 0]),
202+
)

0 commit comments

Comments
 (0)