Skip to content

Commit 33b6e55

Browse files
committed
Naive Bayes: Add tests for prediction
1 parent 66a93e5 commit 33b6e55

File tree

1 file changed

+133
-23
lines changed

1 file changed

+133
-23
lines changed

Orange/tests/test_naive_bayes.py

Lines changed: 133 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# pylint: disable=missing-docstring
33

44
import unittest
5+
import warnings
56
from unittest.mock import Mock
67

78
import numpy as np
@@ -12,6 +13,12 @@
1213
from Orange.evaluation import CrossValidation, CA
1314

1415

16+
# This class is used to force predict_storage to fall back to the slower
17+
# procedure instead of calling `predict`
18+
class NotATable(Table):  # pylint: disable=too-many-ancestors,abstract-method
    """Table subclass that adds no behaviour of its own.

    Tests hand instances of this class to the model when they want
    prediction to take the slower fallback route rather than the direct
    call to `predict`.
    """
20+
21+
1522
class TestNaiveBayesLearner(unittest.TestCase):
1623
@classmethod
1724
def setUpClass(cls):
@@ -32,20 +39,6 @@ def test_NaiveBayes(self):
3239
ca = CA(results)
3340
self.assertGreater(ca, 0.7)
3441

35-
def test_predict_single_instance(self):
36-
for ins in self.table:
37-
self.model(ins)
38-
val, prob = self.model(ins, self.model.ValueProbs)
39-
40-
def test_predict_table(self):
41-
self.model(self.table)
42-
vals, probs = self.model(self.table, self.model.ValueProbs)
43-
44-
def test_predict_numpy(self):
45-
X = self.table.X[::20]
46-
self.model(X)
47-
vals, probs = self.model(X, self.model.ValueProbs)
48-
4942
def test_degenerate(self):
5043
d = Domain((ContinuousVariable(name="A"),
5144
ContinuousVariable(name="B"),
@@ -64,15 +57,6 @@ def test_allnan_cv(self):
6457
results = CrossValidation(data, [self.learner])
6558
self.assertFalse(any(results.failed))
6659

67-
def test_sparse(self):
68-
_, dense_p = self.model.predict(self.data.X)
69-
70-
_, csc_p = self.model.predict(sp.csc_matrix(self.data.X))
71-
np.testing.assert_almost_equal(dense_p, csc_p)
72-
73-
_, csr_p = self.model.predict(sp.csr_matrix(self.data.X))
74-
np.testing.assert_almost_equal(dense_p, csr_p)
75-
7660
def test_prediction_routing(self):
7761
data = self.data
7862
predict = self.model.predict = Mock(return_value=(data.Y, None))
@@ -92,6 +76,132 @@ def test_prediction_routing(self):
9276
self.model.predict_storage(data[0])
9377
predict.assert_not_called()
9478

79+
def test_compare_results_of_storage_and_predict_storage(self):
    """Check that both routes through `predict_storage` give the same result.

    A model fitted on the first 50 rows of `self.data` predicts the
    remaining rows twice: once from `self.data` itself and once from a
    `NotATable` loaded from "titanic", which is meant to make
    `predict_storage` use its slower fallback procedure.
    NOTE(review): this presumes `self.data` holds the same "titanic" rows;
    the fixture is set up outside this view -- confirm.
    """
    data2 = NotATable("titanic")

    self.model = self.learner(self.data[:50])
    values, probs = self.model.predict_storage(self.data[50:])
    values2, probs2 = self.model.predict_storage(data2[50:])

    # Predicted class indices must match exactly.
    np.testing.assert_equal(values, values2)
    # The probabilities come from two different procedures, so they may
    # differ in the last floating-point bits; compare with a tolerance
    # instead of demanding bit-for-bit equality.
    np.testing.assert_almost_equal(probs, probs2)
87+
88+
def test_predictions(self):
89+
self._test_predictions(sparse=None)
90+
91+
def test_predictions_csr_matrix(self):
92+
with warnings.catch_warnings():
93+
warnings.filterwarnings(
94+
"ignore", ".*the matrix subclass.*", PendingDeprecationWarning)
95+
self._test_predictions(sparse=sp.csr_matrix)
96+
97+
def test_predictions_csc_matrix(self):
98+
with warnings.catch_warnings():
99+
warnings.filterwarnings(
100+
"ignore", ".*the matrix subclass.*", PendingDeprecationWarning)
101+
self._test_predictions(sparse=sp.csc_matrix)
102+
103+
def _test_predictions(self, sparse):
    """Check fitted parameters and predictions against hand-computed values.

    `sparse` is either None, in which case the data stays a dense numpy
    array, or a callable (`sp.csr_matrix` / `sp.csc_matrix` in the callers
    above) applied to the training matrix.  The test matrix is converted
    too, except for `sp.csc_matrix`.
    """
    train_x = np.array([
        [1, 0, 0],
        [0, np.nan, 0],
        [0, 1, 0],
        [0, 0, 0],
        [1, 2, 0],
        [1, 1, 0],
        [1, 2, 0],
        [0, 1, 0],
    ])
    if sparse is not None:
        train_x = sparse(train_x)
    train_y = np.array([0, 0, 0, 1, 1, 1, 2, 2])

    attributes = [
        DiscreteVariable("a", values="ab"),
        DiscreteVariable("b", values="abc"),
        DiscreteVariable("c", values="a"),
    ]
    domain = Domain(attributes, DiscreteVariable("y", values="abc"))
    data = Table.from_numpy(domain, train_x, train_y)

    model = self.learner(data)

    # Fitted parameters: class probabilities first, then one table per
    # attribute (rows follow the class values, columns the attribute values).
    np.testing.assert_almost_equal(model.class_prob, [4/11, 4/11, 3/11])
    expected_per_attribute = (
        [[3/7, 2/7], [2/7, 3/7], [2/7, 2/7]],
        [[2/5, 1/3, 1/5], [2/5, 1/3, 2/5], [1/5, 1/3, 2/5]],
        [[4/11], [4/11], [3/11]],
    )
    for attr_index, expected in enumerate(expected_per_attribute):
        np.testing.assert_almost_equal(
            np.exp(model.log_cont_prob[attr_index])
            * model.class_prob[:, None],
            expected)

    # Every combination of values of the first two attributes.
    test_x = np.array([[a, b, 0] for a in (0, 1) for b in (0, 1, 2)])
    # Model.__call__ does not accept csc matrices, so in the csc case the
    # test matrix stays dense; the classifier itself was still fitted on
    # csc data above.
    if not (sparse is None or sparse is sp.csc_matrix):
        test_x = sparse(test_x)
    test_y = np.full((6, ), np.nan)
    # These expected probabilities were computed manually, too.
    exp_probs = np.array([
        [0.47368421052632, 0.31578947368421, 0.21052631578947],
        [0.39130434782609, 0.26086956521739, 0.34782608695652],
        [0.24324324324324, 0.32432432432432, 0.43243243243243],
        [0.31578947368421, 0.47368421052632, 0.21052631578947],
        [0.26086956521739, 0.39130434782609, 0.34782608695652],
        [0.15000000000000, 0.45000000000000, 0.40000000000000],
    ])
    exp_values = np.argmax(exp_probs, axis=1)

    def check_batch(batch):
        # Exercise all three return modes of the model on a whole batch.
        probs = model(batch, ret=model.Probs)
        np.testing.assert_almost_equal(exp_probs, probs)
        values = model(batch)
        np.testing.assert_equal(values, exp_values)
        values, probs = model(batch, ret=model.ValueProbs)
        np.testing.assert_almost_equal(exp_probs, probs)
        np.testing.assert_equal(values, exp_values)

    # The faster algorithm for Table (numpy matrices)
    check_batch(Table.from_numpy(domain, test_x, test_y))

    # The slower algorithm for non-Table data (iteration in Python)
    test_data = NotATable.from_numpy(domain, test_x, test_y)
    check_batch(test_data)

    # Prediction directly on numpy
    check_batch(test_x)

    # Prediction on single instances
    for inst, exp_prob in zip(test_data, exp_probs):
        exp_value = np.argmax(exp_prob)
        np.testing.assert_almost_equal(
            model(inst, ret=model.Probs)[0], exp_prob)
        self.assertEqual(model(inst), exp_value)
        value, prob = model(inst, ret=model.ValueProbs)
        np.testing.assert_almost_equal(prob[0], exp_prob)
        self.assertEqual(value, exp_value)
193+
194+
def test_no_attributes(self):
    """Check predicted probabilities when the domain has no attributes.

    The model is fitted on eight class values only and must return the
    same probabilities, [4/11, 4/11, 3/11], for each of five test rows.
    """
    y = np.array([0, 0, 0, 1, 1, 1, 2, 2])
    class_var = DiscreteVariable("y", values="abc")
    domain = Domain([], class_var)

    # Both tables have zero attribute columns.
    empty_train_x = np.zeros((len(y), 0))
    data = Table.from_numpy(domain, empty_train_x, y.T)
    test_data = Table.from_numpy(domain, np.zeros((5, 0)), np.zeros((5, 1)))

    model = self.learner(data)
    _, probs = model.predict_storage(test_data)[:2]
    expected_row = [4/11, 4/11, 3/11]
    np.testing.assert_almost_equal(probs, [expected_row] * 5)
204+
95205

96206
if __name__ == "__main__":
97207
unittest.main()

0 commit comments

Comments
 (0)