|
17 | 17 |
|
18 | 18 | import pymc as pm |
19 | 19 |
|
20 | | -from pymc import logp |
| 20 | +from pymc import logcdf, logp |
21 | 21 | from pymc.distributions.shape_utils import change_dist_size |
22 | 22 |
|
23 | 23 |
|
@@ -126,3 +126,77 @@ def test_censored_categorical(self): |
126 | 126 | logp(censored_cat, [-1, 0, 1, 2, 3, 4, 5]).exp().eval(), |
127 | 127 | [0, 0, 0.3, 0.2, 0.5, 0, 0], |
128 | 128 | ) |
| 129 | + |
| 130 | + def test_censored_logcdf_continuous(self): |
| 131 | + norm = pm.Normal.dist(0, 1) |
| 132 | + eval_points = np.array([-np.inf, -2, -1, 0, 1, 2, np.inf]) |
| 133 | + |
| 134 | + # No censoring |
| 135 | + censored_norm = pm.Censored.dist(norm, lower=None, upper=None) |
| 136 | + with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"): |
| 137 | + censored_eval = logcdf(censored_norm, eval_points).eval() |
| 138 | + with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"): |
| 139 | + norm_eval = logcdf(norm, eval_points).eval() |
| 140 | + np.testing.assert_allclose(censored_eval, norm_eval) |
| 141 | + |
| 142 | + # Left censoring |
| 143 | + censored_norm = pm.Censored.dist(norm, lower=-1, upper=None) |
| 144 | + with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"): |
| 145 | + censored_eval = logcdf(censored_norm, eval_points).eval() |
| 146 | + np.testing.assert_allclose( |
| 147 | + censored_eval, |
| 148 | + np.array([-np.inf, -np.inf, -1.84102167, -0.69314718, -0.17275377, -0.02301291, 0.0]), |
| 149 | + rtol=1e-6, |
| 150 | + ) |
| 151 | + |
| 152 | + # Right censoring |
| 153 | + censored_norm = pm.Censored.dist(norm, lower=None, upper=1) |
| 154 | + with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"): |
| 155 | + censored_eval = logcdf(censored_norm, eval_points).eval() |
| 156 | + np.testing.assert_allclose( |
| 157 | + censored_eval, |
| 158 | + np.array([-np.inf, -3.78318435, -1.84102167, -0.69314718, 0, 0, 0.0]), |
| 159 | + rtol=1e-6, |
| 160 | + ) |
| 161 | + |
| 162 | + # Interval censoring |
| 163 | + censored_norm = pm.Censored.dist(norm, lower=-1, upper=1) |
| 164 | + with pytest.warns(RuntimeWarning, match="divide by zero encountered in log"): |
| 165 | + censored_eval = logcdf(censored_norm, eval_points).eval() |
| 166 | + np.testing.assert_allclose( |
| 167 | + censored_eval, |
| 168 | + np.array([-np.inf, -np.inf, -1.84102167, -0.69314718, 0, 0, 0.0]), |
| 169 | + rtol=1e-6, |
| 170 | + ) |
| 171 | + |
| 172 | + def test_censored_logcdf_discrete(self): |
| 173 | + cat = pm.Categorical.dist([0.1, 0.2, 0.2, 0.3, 0.2]) |
| 174 | + eval_points = np.array([-1, 0, 1, 2, 3, 4, 5]) |
| 175 | + |
| 176 | + # No censoring |
| 177 | + censored_cat = pm.Censored.dist(cat, lower=None, upper=None) |
| 178 | + np.testing.assert_allclose( |
| 179 | + logcdf(censored_cat, eval_points).eval(), |
| 180 | + logcdf(cat, eval_points).eval(), |
| 181 | + ) |
| 182 | + |
| 183 | + # Left censoring |
| 184 | + censored_cat = pm.Censored.dist(cat, lower=1, upper=None) |
| 185 | + np.testing.assert_allclose( |
| 186 | + logcdf(censored_cat, eval_points).eval(), |
| 187 | + np.array([-np.inf, -np.inf, -1.2039728, -0.69314718, -0.22314355, 0, 0]), |
| 188 | + ) |
| 189 | + |
| 190 | + # Right censoring |
| 191 | + censored_cat = pm.Censored.dist(cat, lower=None, upper=3) |
| 192 | + np.testing.assert_allclose( |
| 193 | + logcdf(censored_cat, eval_points).eval(), |
| 194 | + np.array([-np.inf, -2.30258509, -1.2039728, -0.69314718, 0, 0, 0]), |
| 195 | + ) |
| 196 | + |
| 197 | + # Interval censoring |
| 198 | + censored_cat = pm.Censored.dist(cat, lower=1, upper=3) |
| 199 | + np.testing.assert_allclose( |
| 200 | + logcdf(censored_cat, eval_points).eval(), |
| 201 | + np.array([-np.inf, -np.inf, -1.2039728, -0.69314718, 0, 0, 0]), |
| 202 | + ) |
0 commit comments