Skip to content

Commit a125914

Browse files
committed
feat #6686: make tests more readable
1 parent 024c7e9 commit a125914

File tree

1 file changed

+26
-11
lines changed

1 file changed

+26
-11
lines changed

tests/distributions/test_censored.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
import pytest
17+
import scipy as sp
1718

1819
import pymc as pm
1920

@@ -130,75 +131,89 @@ def test_censored_categorical(self):
130131
def test_censored_logcdf_continuous(self):
131132
norm = pm.Normal.dist(0, 1)
132133
eval_points = np.array([-np.inf, -2, -1, 0, 1, 2, np.inf])
134+
expected_logcdf_uncensored = sp.stats.norm.logcdf(eval_points)
133135

134136
match_str = "divide by zero encountered in log|invalid value encountered in subtract"
135137

136138
# No censoring
137139
censored_norm = pm.Censored.dist(norm, lower=None, upper=None)
138140
with pytest.warns(RuntimeWarning, match=match_str):
139141
censored_eval = logcdf(censored_norm, eval_points).eval()
140-
with pytest.warns(RuntimeWarning, match=match_str):
141-
norm_eval = logcdf(norm, eval_points).eval()
142-
np.testing.assert_allclose(censored_eval, norm_eval)
142+
np.testing.assert_allclose(censored_eval, expected_logcdf_uncensored)
143143

144144
# Left censoring
145145
censored_norm = pm.Censored.dist(norm, lower=-1, upper=None)
146+
expected_left = np.where(eval_points < -1, -np.inf, expected_logcdf_uncensored)
146147
with pytest.warns(RuntimeWarning, match=match_str):
147148
censored_eval = logcdf(censored_norm, eval_points).eval()
148149
np.testing.assert_allclose(
149150
censored_eval,
150-
np.array([-np.inf, -np.inf, -1.84102167, -0.69314718, -0.17275377, -0.02301291, 0.0]),
151+
expected_left,
151152
rtol=1e-6,
152153
)
153154

154155
# Right censoring
155156
censored_norm = pm.Censored.dist(norm, lower=None, upper=1)
157+
expected_right = np.where(eval_points >= 1, 0.0, expected_logcdf_uncensored)
156158
with pytest.warns(RuntimeWarning, match=match_str):
157159
censored_eval = logcdf(censored_norm, eval_points).eval()
158160
np.testing.assert_allclose(
159161
censored_eval,
160-
np.array([-np.inf, -3.78318435, -1.84102167, -0.69314718, 0, 0, 0.0]),
162+
expected_right,
161163
rtol=1e-6,
162164
)
163165

164166
# Interval censoring
165167
censored_norm = pm.Censored.dist(norm, lower=-1, upper=1)
168+
expected_interval = np.where(eval_points < -1, -np.inf, expected_logcdf_uncensored)
169+
expected_interval = np.where(eval_points >= 1, 0.0, expected_interval)
166170
with pytest.warns(RuntimeWarning, match=match_str):
167171
censored_eval = logcdf(censored_norm, eval_points).eval()
168172
np.testing.assert_allclose(
169173
censored_eval,
170-
np.array([-np.inf, -np.inf, -1.84102167, -0.69314718, 0, 0, 0.0]),
174+
expected_interval,
171175
rtol=1e-6,
172176
)
173177

174178
def test_censored_logcdf_discrete(self):
175-
cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2])
179+
probs = [0.1, 0.2, 0.2, 0.3, 0.2]
180+
cat = pm.Categorical.dist(probs)
176181
eval_points = np.array([-1, 0, 1, 2, 3, 4, 5])
177182

183+
cdf = np.cumsum(probs)
184+
log_cdf_base = np.log(cdf)
185+
expected_logcdf_uncensored = np.full_like(eval_points, -np.inf, dtype=float)
186+
expected_logcdf_uncensored[1:6] = log_cdf_base
187+
expected_logcdf_uncensored[6] = 0.0
188+
178189
# No censoring
179190
censored_cat = pm.Censored.dist(cat, lower=None, upper=None)
180191
np.testing.assert_allclose(
181192
logcdf(censored_cat, eval_points).eval(),
182-
logcdf(cat, eval_points).eval(),
193+
expected_logcdf_uncensored,
183194
)
184195

185196
# Left censoring
186197
censored_cat = pm.Censored.dist(cat, lower=1, upper=None)
198+
expected_left = np.where(eval_points < 1, -np.inf, expected_logcdf_uncensored)
187199
np.testing.assert_allclose(
188200
logcdf(censored_cat, eval_points).eval(),
189-
np.array([-np.inf, -np.inf, -1.2039728, -0.69314718, -0.22314355, 0, 0]),
201+
expected_left,
190202
)
191203

192204
# Right censoring
193205
censored_cat = pm.Censored.dist(cat, lower=None, upper=3)
206+
expected_right = np.where(eval_points >= 3, 0.0, expected_logcdf_uncensored)
194207
np.testing.assert_allclose(
195208
logcdf(censored_cat, eval_points).eval(),
196-
np.array([-np.inf, -2.30258509, -1.2039728, -0.69314718, 0, 0, 0]),
209+
expected_right,
197210
)
198211

199212
# Interval censoring
200213
censored_cat = pm.Censored.dist(cat, lower=1, upper=3)
214+
expected_interval = np.where(eval_points < 1, -np.inf, expected_logcdf_uncensored)
215+
expected_interval = np.where(eval_points >= 3, 0.0, expected_interval)
201216
np.testing.assert_allclose(
202217
logcdf(censored_cat, eval_points).eval(),
203-
np.array([-np.inf, -np.inf, -1.2039728, -0.69314718, 0, 0, 0]),
218+
expected_interval,
204219
)

0 commit comments

Comments
 (0)