Skip to content

Commit abd0a63

Browse files
rpgoldmantwiecki
authored andcommitted
Add more tests for missing data. (#3673)
1 parent 639c7bd commit abd0a63

File tree

1 file changed

+62
-5
lines changed

1 file changed

+62
-5
lines changed

pymc3/tests/test_missing.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,85 @@
1-
from pymc3 import Model, Normal
2-
from numpy import ma
1+
import pytest
2+
from numpy import ma, array
33
import numpy
44
import pandas as pd
5-
5+
from pymc3 import Model, Normal, sample_prior_predictive, sample, ImputationWarning
66

77
def test_missing():
88
data = ma.masked_values([1, 2, -1, 4, -1], value=-1)
99
with Model() as model:
1010
x = Normal('x', 1, 1)
11-
Normal('y', x, 1, observed=data)
11+
with pytest.warns(ImputationWarning):
12+
Normal('y', x, 1, observed=data)
1213

1314
y_missing, = model.missing_values
1415
assert y_missing.tag.test_value.shape == (2,)
1516

1617
model.logp(model.test_point)
1718

19+
with model:
20+
prior_trace = sample_prior_predictive()
21+
assert set(['x', 'y']) <= set(prior_trace.keys())
22+
1823

1924
def test_missing_pandas():
2025
data = pd.DataFrame([1, 2, numpy.nan, 4, numpy.nan])
2126
with Model() as model:
2227
x = Normal('x', 1, 1)
23-
Normal('y', x, 1, observed=data)
28+
with pytest.warns(ImputationWarning):
29+
Normal('y', x, 1, observed=data)
30+
31+
y_missing, = model.missing_values
32+
assert y_missing.tag.test_value.shape == (2,)
33+
34+
model.logp(model.test_point)
35+
36+
with model:
37+
prior_trace = sample_prior_predictive()
38+
assert set(['x', 'y']) <= set(prior_trace.keys())
39+
40+
def test_missing_with_predictors():
41+
predictors = array([0.5, 1, 0.5, 2, 0.3])
42+
data = ma.masked_values([1, 2, -1, 4, -1], value=-1)
43+
with Model() as model:
44+
x = Normal('x', 1, 1)
45+
with pytest.warns(ImputationWarning):
46+
Normal('y', x * predictors, 1, observed=data)
2447

2548
y_missing, = model.missing_values
2649
assert y_missing.tag.test_value.shape == (2,)
2750

2851
model.logp(model.test_point)
52+
53+
with model:
54+
prior_trace = sample_prior_predictive()
55+
assert set(['x', 'y']) <= set(prior_trace.keys())
56+
57+
58+
def test_missing_dual_observations():
59+
with Model() as model:
60+
obs1 = ma.masked_values([1, 2, -1, 4, -1], value=-1)
61+
obs2 = ma.masked_values([-1, -1, 6, -1, 8], value=-1)
62+
beta1 = Normal('beta1', 1, 1)
63+
beta2 = Normal('beta2', 2, 1)
64+
latent = Normal('theta', shape=5)
65+
with pytest.warns(ImputationWarning):
66+
ovar1 = Normal('o1', mu=beta1 * latent, observed=obs1)
67+
with pytest.warns(ImputationWarning):
68+
ovar2 = Normal('o2', mu=beta2 * latent, observed=obs2)
69+
70+
prior_trace = sample_prior_predictive()
71+
assert set(['beta1', 'beta2', 'theta', 'o1', 'o2']) <= set(prior_trace.keys())
72+
sample()
73+
74+
def test_internal_missing_observations():
75+
with Model() as model:
76+
obs1 = ma.masked_values([1, 2, -1, 4, -1], value=-1)
77+
obs2 = ma.masked_values([-1, -1, 6, -1, 8], value=-1)
78+
with pytest.warns(ImputationWarning):
79+
theta1 = Normal('theta1', mu=2, observed=obs1)
80+
with pytest.warns(ImputationWarning):
81+
theta2 = Normal('theta2', mu=theta1, observed=obs2)
82+
83+
prior_trace = sample_prior_predictive()
84+
assert set(['theta1', 'theta2']) <= set(prior_trace.keys())
85+
sample()

0 commit comments

Comments
 (0)