|
14 | 14 |
|
15 | 15 | import numpy as np
|
16 | 16 | import pytest
|
| 17 | +import scipy as sp |
17 | 18 |
|
18 | 19 | import pymc as pm
|
19 | 20 |
|
20 |
| -from pymc import logp |
| 21 | +from pymc import logcdf, logp |
21 | 22 | from pymc.distributions.shape_utils import change_dist_size
|
22 | 23 |
|
23 | 24 |
|
@@ -126,3 +127,93 @@ def test_censored_categorical(self):
|
126 | 127 | logp(censored_cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(),
|
127 | 128 | [0, 0, 0.3, 0.2, 0.5, 0, 0],
|
128 | 129 | )
|
| 130 | + |
| 131 | + def test_censored_logcdf_continuous(self): |
| 132 | + norm = pm.Normal.dist(0, 1) |
| 133 | + eval_points = np.array([-np.inf, -2, -1, 0, 1, 2, np.inf]) |
| 134 | + expected_logcdf_uncensored = sp.stats.norm.logcdf(eval_points) |
| 135 | + |
| 136 | + match_str = "divide by zero encountered in log|invalid value encountered in subtract" |
| 137 | + |
| 138 | + # No censoring |
| 139 | + censored_norm = pm.Censored.dist(norm, lower=None, upper=None) |
| 140 | + with pytest.warns(RuntimeWarning, match=match_str): |
| 141 | + censored_eval = logcdf(censored_norm, eval_points).eval() |
| 142 | + np.testing.assert_allclose(censored_eval, expected_logcdf_uncensored) |
| 143 | + |
| 144 | + # Left censoring |
| 145 | + censored_norm = pm.Censored.dist(norm, lower=-1, upper=None) |
| 146 | + expected_left = np.where(eval_points < -1, -np.inf, expected_logcdf_uncensored) |
| 147 | + with pytest.warns(RuntimeWarning, match=match_str): |
| 148 | + censored_eval = logcdf(censored_norm, eval_points).eval() |
| 149 | + np.testing.assert_allclose( |
| 150 | + censored_eval, |
| 151 | + expected_left, |
| 152 | + rtol=1e-6, |
| 153 | + ) |
| 154 | + |
| 155 | + # Right censoring |
| 156 | + censored_norm = pm.Censored.dist(norm, lower=None, upper=1) |
| 157 | + expected_right = np.where(eval_points >= 1, 0.0, expected_logcdf_uncensored) |
| 158 | + with pytest.warns(RuntimeWarning, match=match_str): |
| 159 | + censored_eval = logcdf(censored_norm, eval_points).eval() |
| 160 | + np.testing.assert_allclose( |
| 161 | + censored_eval, |
| 162 | + expected_right, |
| 163 | + rtol=1e-6, |
| 164 | + ) |
| 165 | + |
| 166 | + # Interval censoring |
| 167 | + censored_norm = pm.Censored.dist(norm, lower=-1, upper=1) |
| 168 | + expected_interval = np.where(eval_points < -1, -np.inf, expected_logcdf_uncensored) |
| 169 | + expected_interval = np.where(eval_points >= 1, 0.0, expected_interval) |
| 170 | + with pytest.warns(RuntimeWarning, match=match_str): |
| 171 | + censored_eval = logcdf(censored_norm, eval_points).eval() |
| 172 | + np.testing.assert_allclose( |
| 173 | + censored_eval, |
| 174 | + expected_interval, |
| 175 | + rtol=1e-6, |
| 176 | + ) |
| 177 | + |
| 178 | + def test_censored_logcdf_discrete(self): |
| 179 | + probs = [0.1, 0.2, 0.2, 0.3, 0.2] |
| 180 | + cat = pm.Categorical.dist(probs) |
| 181 | + eval_points = np.array([-1, 0, 1, 2, 3, 4, 5]) |
| 182 | + |
| 183 | + cdf = np.cumsum(probs) |
| 184 | + log_cdf_base = np.log(cdf) |
| 185 | + expected_logcdf_uncensored = np.full_like(eval_points, -np.inf, dtype=float) |
| 186 | + expected_logcdf_uncensored[1:6] = log_cdf_base |
| 187 | + expected_logcdf_uncensored[6] = 0.0 |
| 188 | + |
| 189 | + # No censoring |
| 190 | + censored_cat = pm.Censored.dist(cat, lower=None, upper=None) |
| 191 | + np.testing.assert_allclose( |
| 192 | + logcdf(censored_cat, eval_points).eval(), |
| 193 | + expected_logcdf_uncensored, |
| 194 | + ) |
| 195 | + |
| 196 | + # Left censoring |
| 197 | + censored_cat = pm.Censored.dist(cat, lower=1, upper=None) |
| 198 | + expected_left = np.where(eval_points < 1, -np.inf, expected_logcdf_uncensored) |
| 199 | + np.testing.assert_allclose( |
| 200 | + logcdf(censored_cat, eval_points).eval(), |
| 201 | + expected_left, |
| 202 | + ) |
| 203 | + |
| 204 | + # Right censoring |
| 205 | + censored_cat = pm.Censored.dist(cat, lower=None, upper=3) |
| 206 | + expected_right = np.where(eval_points >= 3, 0.0, expected_logcdf_uncensored) |
| 207 | + np.testing.assert_allclose( |
| 208 | + logcdf(censored_cat, eval_points).eval(), |
| 209 | + expected_right, |
| 210 | + ) |
| 211 | + |
| 212 | + # Interval censoring |
| 213 | + censored_cat = pm.Censored.dist(cat, lower=1, upper=3) |
| 214 | + expected_interval = np.where(eval_points < 1, -np.inf, expected_logcdf_uncensored) |
| 215 | + expected_interval = np.where(eval_points >= 3, 0.0, expected_interval) |
| 216 | + np.testing.assert_allclose( |
| 217 | + logcdf(censored_cat, eval_points).eval(), |
| 218 | + expected_interval, |
| 219 | + ) |
0 commit comments