|
14 | 14 |
|
15 | 15 | import numpy as np |
16 | 16 | import pytest |
| 17 | +import scipy as sp |
17 | 18 |
|
18 | 19 | import pymc as pm |
19 | 20 |
|
@@ -130,75 +131,89 @@ def test_censored_categorical(self): |
130 | 131 | def test_censored_logcdf_continuous(self): |
131 | 132 | norm = pm.Normal.dist(0, 1) |
132 | 133 | eval_points = np.array([-np.inf, -2, -1, 0, 1, 2, np.inf]) |
| 134 | + expected_logcdf_uncensored = sp.stats.norm.logcdf(eval_points) |
133 | 135 |
|
134 | 136 | match_str = "divide by zero encountered in log|invalid value encountered in subtract" |
135 | 137 |
|
136 | 138 | # No censoring |
137 | 139 | censored_norm = pm.Censored.dist(norm, lower=None, upper=None) |
138 | 140 | with pytest.warns(RuntimeWarning, match=match_str): |
139 | 141 | censored_eval = logcdf(censored_norm, eval_points).eval() |
140 | | - with pytest.warns(RuntimeWarning, match=match_str): |
141 | | - norm_eval = logcdf(norm, eval_points).eval() |
142 | | - np.testing.assert_allclose(censored_eval, norm_eval) |
| 142 | + np.testing.assert_allclose(censored_eval, expected_logcdf_uncensored) |
143 | 143 |
|
144 | 144 | # Left censoring |
145 | 145 | censored_norm = pm.Censored.dist(norm, lower=-1, upper=None) |
| 146 | + expected_left = np.where(eval_points < -1, -np.inf, expected_logcdf_uncensored) |
146 | 147 | with pytest.warns(RuntimeWarning, match=match_str): |
147 | 148 | censored_eval = logcdf(censored_norm, eval_points).eval() |
148 | 149 | np.testing.assert_allclose( |
149 | 150 | censored_eval, |
150 | | - np.array([-np.inf, -np.inf, -1.84102167, -0.69314718, -0.17275377, -0.02301291, 0.0]), |
| 151 | + expected_left, |
151 | 152 | rtol=1e-6, |
152 | 153 | ) |
153 | 154 |
|
154 | 155 | # Right censoring |
155 | 156 | censored_norm = pm.Censored.dist(norm, lower=None, upper=1) |
| 157 | + expected_right = np.where(eval_points >= 1, 0.0, expected_logcdf_uncensored) |
156 | 158 | with pytest.warns(RuntimeWarning, match=match_str): |
157 | 159 | censored_eval = logcdf(censored_norm, eval_points).eval() |
158 | 160 | np.testing.assert_allclose( |
159 | 161 | censored_eval, |
160 | | - np.array([-np.inf, -3.78318435, -1.84102167, -0.69314718, 0, 0, 0.0]), |
| 162 | + expected_right, |
161 | 163 | rtol=1e-6, |
162 | 164 | ) |
163 | 165 |
|
164 | 166 | # Interval censoring |
165 | 167 | censored_norm = pm.Censored.dist(norm, lower=-1, upper=1) |
| 168 | + expected_interval = np.where(eval_points < -1, -np.inf, expected_logcdf_uncensored) |
| 169 | + expected_interval = np.where(eval_points >= 1, 0.0, expected_interval) |
166 | 170 | with pytest.warns(RuntimeWarning, match=match_str): |
167 | 171 | censored_eval = logcdf(censored_norm, eval_points).eval() |
168 | 172 | np.testing.assert_allclose( |
169 | 173 | censored_eval, |
170 | | - np.array([-np.inf, -np.inf, -1.84102167, -0.69314718, 0, 0, 0.0]), |
| 174 | + expected_interval, |
171 | 175 | rtol=1e-6, |
172 | 176 | ) |
173 | 177 |
|
174 | 178 | def test_censored_logcdf_discrete(self): |
175 | | - cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2]) |
| 179 | + probs = [0.1, 0.2, 0.2, 0.3, 0.2] |
| 180 | + cat = pm.Categorical.dist(probs) |
176 | 181 | eval_points = np.array([-1, 0, 1, 2, 3, 4, 5]) |
177 | 182 |
|
| 183 | + cdf = np.cumsum(probs) |
| 184 | + log_cdf_base = np.log(cdf) |
| 185 | + expected_logcdf_uncensored = np.full_like(eval_points, -np.inf, dtype=float) |
| 186 | + expected_logcdf_uncensored[1:6] = log_cdf_base |
| 187 | + expected_logcdf_uncensored[6] = 0.0 |
| 188 | + |
178 | 189 | # No censoring |
179 | 190 | censored_cat = pm.Censored.dist(cat, lower=None, upper=None) |
180 | 191 | np.testing.assert_allclose( |
181 | 192 | logcdf(censored_cat, eval_points).eval(), |
182 | | - logcdf(cat, eval_points).eval(), |
| 193 | + expected_logcdf_uncensored, |
183 | 194 | ) |
184 | 195 |
|
185 | 196 | # Left censoring |
186 | 197 | censored_cat = pm.Censored.dist(cat, lower=1, upper=None) |
| 198 | + expected_left = np.where(eval_points < 1, -np.inf, expected_logcdf_uncensored) |
187 | 199 | np.testing.assert_allclose( |
188 | 200 | logcdf(censored_cat, eval_points).eval(), |
189 | | - np.array([-np.inf, -np.inf, -1.2039728, -0.69314718, -0.22314355, 0, 0]), |
| 201 | + expected_left, |
190 | 202 | ) |
191 | 203 |
|
192 | 204 | # Right censoring |
193 | 205 | censored_cat = pm.Censored.dist(cat, lower=None, upper=3) |
| 206 | + expected_right = np.where(eval_points >= 3, 0.0, expected_logcdf_uncensored) |
194 | 207 | np.testing.assert_allclose( |
195 | 208 | logcdf(censored_cat, eval_points).eval(), |
196 | | - np.array([-np.inf, -2.30258509, -1.2039728, -0.69314718, 0, 0, 0]), |
| 209 | + expected_right, |
197 | 210 | ) |
198 | 211 |
|
199 | 212 | # Interval censoring |
200 | 213 | censored_cat = pm.Censored.dist(cat, lower=1, upper=3) |
| 214 | + expected_interval = np.where(eval_points < 1, -np.inf, expected_logcdf_uncensored) |
| 215 | + expected_interval = np.where(eval_points >= 3, 0.0, expected_interval) |
201 | 216 | np.testing.assert_allclose( |
202 | 217 | logcdf(censored_cat, eval_points).eval(), |
203 | | - np.array([-np.inf, -np.inf, -1.2039728, -0.69314718, 0, 0, 0]), |
| 218 | + expected_interval, |
204 | 219 | ) |
0 commit comments