Skip to content

Commit 67fc41c

Browse files
committed
test naive bayes
1 parent 484becf commit 67fc41c

File tree

1 file changed

+75
-161
lines changed

1 file changed

+75
-161
lines changed

Orange/tests/test_naive_bayes.py

Lines changed: 75 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,42 @@
1010
from Orange.classification import NaiveBayesLearner
1111
from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable
1212
from Orange.evaluation import CrossValidation, CA
13+
from Orange.tests import test_filename
1314

1415

1516
# This class is used to force predict_storage to fall back to the slower
1617
# procedure instead of calling `predict`
17-
from Orange.tests import test_filename
18-
19-
2018
class NotATable(Table): # pylint: disable=too-many-ancestors,abstract-method
2119
@classmethod
2220
def from_file(cls, *args, **kwargs):
2321
table = super().from_file(*args, **kwargs)
2422
return cls(table)
2523

2624

25+
def assert_predictions_equal(data, model, exp_probs):
26+
exp_vals = np.argmax(np.atleast_2d(exp_probs), axis=1)
27+
np.testing.assert_almost_equal(model(data, ret=model.Probs), exp_probs)
28+
np.testing.assert_equal(model(data), exp_vals)
29+
values, probs = model(data, ret=model.ValueProbs)
30+
np.testing.assert_almost_equal(probs, exp_probs)
31+
np.testing.assert_equal(values, exp_vals)
32+
33+
34+
def assert_model_equal(model, results):
35+
np.testing.assert_almost_equal(
36+
model.class_prob,
37+
results[0])
38+
np.testing.assert_almost_equal(
39+
np.exp(model.log_cont_prob[0]) * model.class_prob[:, None],
40+
results[1])
41+
np.testing.assert_almost_equal(
42+
np.exp(model.log_cont_prob[1]) * model.class_prob[:, None],
43+
results[2])
44+
np.testing.assert_almost_equal(
45+
np.exp(model.log_cont_prob[2]) * model.class_prob[:, None],
46+
results[3])
47+
48+
2749
class TestNaiveBayesLearner(unittest.TestCase):
2850
@classmethod
2951
def setUpClass(cls):
@@ -100,20 +122,22 @@ def test_compare_results_of_predict_and_predict_storage(self):
100122

101123
def test_predictions(self):
102124
self._test_predictions(sparse=None)
103-
self._test_predictions_with_absent_class(sparse=None)
125+
self._test_predictions(sparse=None, absent_class=True)
104126
self._test_predict_missing_attributes(sparse=None)
105127

106128
def test_predictions_csr_matrix(self):
107129
self._test_predictions(sparse=sp.csr_matrix)
108-
self._test_predictions_with_absent_class(sparse=sp.csr_matrix)
130+
self._test_predictions(sparse=sp.csr_matrix, absent_class=True)
109131
self._test_predict_missing_attributes(sparse=sp.csr_matrix)
110132

111133
def test_predictions_csc_matrix(self):
112134
self._test_predictions(sparse=sp.csc_matrix)
113-
self._test_predictions_with_absent_class(sparse=sp.csc_matrix)
135+
self._test_predictions(sparse=sp.csc_matrix, absent_class=True)
114136
self._test_predict_missing_attributes(sparse=sp.csc_matrix)
115137

116-
def _test_predictions(self, sparse):
138+
@staticmethod
139+
def _create_prediction_data(sparse, absent_class=False):
140+
""" The following was computed manually """
117141
x = np.array([
118142
[1, 0, 0],
119143
[0, np.nan, 0],
@@ -127,27 +151,13 @@ def _test_predictions(self, sparse):
127151
x = sparse(x)
128152

129153
y = np.array([0, 0, 0, 1, 1, 1, 2, 2])
130-
domain = Domain(
131-
[DiscreteVariable("a", values="ab"),
132-
DiscreteVariable("b", values="abc"),
133-
DiscreteVariable("c", values="a")],
134-
DiscreteVariable("y", values="abc"))
135-
data = Table.from_numpy(domain, x, y)
136-
137-
model = self.learner(data)
138-
np.testing.assert_almost_equal(
139-
model.class_prob,
140-
[4/11, 4/11, 3/11]
141-
)
142-
np.testing.assert_almost_equal(
143-
np.exp(model.log_cont_prob[0]) * model.class_prob[:, None],
144-
[[3/7, 2/7], [2/7, 3/7], [2/7, 2/7]])
145-
np.testing.assert_almost_equal(
146-
np.exp(model.log_cont_prob[1]) * model.class_prob[:, None],
147-
[[2/5, 1/3, 1/5], [2/5, 1/3, 2/5], [1/5, 1/3, 2/5]])
148-
np.testing.assert_almost_equal(
149-
np.exp(model.log_cont_prob[2]) * model.class_prob[:, None],
150-
[[4/11], [4/11], [3/11]])
154+
class_var = DiscreteVariable("y", values="abc")
155+
results = [
156+
[4/11, 4/11, 3/11],
157+
[[3/7, 2/7], [2/7, 3/7], [2/7, 2/7]],
158+
[[2/5, 1/3, 1/5], [2/5, 1/3, 2/5], [1/5, 1/3, 2/5]],
159+
[[4/11], [4/11], [3/11]]
160+
]
151161

152162
test_x = np.array([[a, b, 0] for a in [0, 1] for b in [0, 1, 2]])
153163
# Classifiers reject csc matrices in the base class
@@ -156,7 +166,7 @@ def _test_predictions(self, sparse):
156166
if sparse is not None and sparse is not sp.csc_matrix:
157167
test_x = sparse(test_x)
158168
test_y = np.full((6, ), np.nan)
159-
# The following was computed manually, too
169+
160170
exp_probs = np.array([
161171
[0.47368421052632, 0.31578947368421, 0.21052631578947],
162172
[0.39130434782609, 0.26086956521739, 0.34782608695652],
@@ -166,155 +176,54 @@ def _test_predictions(self, sparse):
166176
[0.15000000000000, 0.45000000000000, 0.40000000000000]
167177
])
168178

169-
# Test the faster algorithm for Table (numpy matrices)
170-
test_data = Table.from_numpy(domain, test_x, test_y)
171-
probs = model(test_data, ret=model.Probs)
172-
np.testing.assert_almost_equal(exp_probs, probs)
173-
values = model(test_data)
174-
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
175-
values, probs = model(test_data, ret=model.ValueProbs)
176-
np.testing.assert_almost_equal(exp_probs, probs)
177-
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
178-
179-
# Test the slower algorithm for non-Table data (iteration in Python)
180-
test_data = NotATable.from_numpy(domain, test_x, test_y)
181-
probs = model(test_data, ret=model.Probs)
182-
np.testing.assert_almost_equal(exp_probs, probs)
183-
values = model(test_data)
184-
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
185-
values, probs = model(test_data, ret=model.ValueProbs)
186-
np.testing.assert_almost_equal(exp_probs, probs)
187-
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
188-
189-
# Test prediction directly on numpy
190-
probs = model(test_x, ret=model.Probs)
191-
np.testing.assert_almost_equal(exp_probs, probs)
192-
values = model(test_x)
193-
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
194-
values, probs = model(test_x, ret=model.ValueProbs)
195-
np.testing.assert_almost_equal(exp_probs, probs)
196-
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
197-
198-
# Test prediction on instances
199-
for inst, exp_prob in zip(test_data, exp_probs):
200-
np.testing.assert_almost_equal(
201-
model(inst, ret=model.Probs),
202-
exp_prob)
203-
self.assertEqual(model(inst), np.argmax(exp_prob))
204-
value, prob = model(inst, ret=model.ValueProbs)
205-
np.testing.assert_almost_equal(prob, exp_prob)
206-
self.assertEqual(value, np.argmax(exp_prob))
207-
208-
# Test prediction by directly calling predict. This is needed to test
209-
# csc_matrix, but doesn't hurt others
210-
if sparse is sp.csc_matrix:
211-
test_x = sparse(test_x)
212-
values, probs = model.predict(test_x)
213-
np.testing.assert_almost_equal(exp_probs, probs)
214-
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
215-
216-
def _test_predictions_with_absent_class(self, sparse):
217-
"""Empty classes should not affect predictions"""
218-
x = np.array([
219-
[1, 0, 0],
220-
[0, np.nan, 0],
221-
[0, 1, 0],
222-
[0, 0, 0],
223-
[1, 2, 0],
224-
[1, 1, 0],
225-
[1, 2, 0],
226-
[0, 1, 0]])
227-
if sparse is not None:
228-
x = sparse(x)
179+
if absent_class:
180+
y = np.array([0, 0, 0, 2, 2, 2, 3, 3])
181+
class_var = DiscreteVariable("y", values="abcd")
182+
for i, row in enumerate(results):
183+
row.insert(1, i and [0]*len(row[0]))
184+
exp_probs = np.insert(exp_probs, 1, 0, axis=1)
229185

230-
y = np.array([0, 0, 0, 2, 2, 2, 3, 3])
231186
domain = Domain(
232187
[DiscreteVariable("a", values="ab"),
233188
DiscreteVariable("b", values="abc"),
234189
DiscreteVariable("c", values="a")],
235-
DiscreteVariable("y", values="abcd"))
190+
class_var)
236191
data = Table.from_numpy(domain, x, y)
237192

238-
model = self.learner(data)
239-
np.testing.assert_almost_equal(
240-
model.class_prob,
241-
[4/11, 0, 4/11, 3/11]
242-
)
243-
np.testing.assert_almost_equal(
244-
np.exp(model.log_cont_prob[0]) * model.class_prob[:, None],
245-
[[3/7, 2/7], [0, 0], [2/7, 3/7], [2/7, 2/7]])
246-
np.testing.assert_almost_equal(
247-
np.exp(model.log_cont_prob[1]) * model.class_prob[:, None],
248-
[[2/5, 1/3, 1/5], [0, 0, 0], [2/5, 1/3, 2/5], [1/5, 1/3, 2/5]])
249-
np.testing.assert_almost_equal(
250-
np.exp(model.log_cont_prob[2]) * model.class_prob[:, None],
251-
[[4/11], [0], [4/11], [3/11]])
193+
return data, domain, results, test_x, test_y, exp_probs
252194

253-
test_x = np.array([[a, b, 0] for a in [0, 1] for b in [0, 1, 2]])
254-
# Classifiers reject csc matrices in the base class
255-
# Naive bayesian classifier supports them if predict_storage is
256-
# called directly, which we do below
257-
if sparse is not None and sparse is not sp.csc_matrix:
258-
test_x = sparse(test_x)
259-
test_y = np.full((6, ), np.nan)
260-
# The following was computed manually, too
261-
exp_probs = np.array([
262-
[0.47368421052632, 0, 0.31578947368421, 0.21052631578947],
263-
[0.39130434782609, 0, 0.26086956521739, 0.34782608695652],
264-
[0.24324324324324, 0, 0.32432432432432, 0.43243243243243],
265-
[0.31578947368421, 0, 0.47368421052632, 0.21052631578947],
266-
[0.26086956521739, 0, 0.39130434782609, 0.34782608695652],
267-
[0.15000000000000, 0, 0.45000000000000, 0.40000000000000]
268-
])
195+
def _test_predictions(self, sparse, absent_class=False):
196+
(data, domain, results,
197+
test_x, test_y, exp_probs) = self._create_prediction_data(sparse, absent_class)
198+
199+
model = self.learner(data)
200+
assert_model_equal(model, results)
269201

270202
# Test the faster algorithm for Table (numpy matrices)
271203
test_data = Table.from_numpy(domain, test_x, test_y)
272-
probs = model(test_data, ret=model.Probs)
273-
np.testing.assert_almost_equal(exp_probs, probs)
274-
values = model(test_data)
275-
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
276-
values, probs = model(test_data, ret=model.ValueProbs)
277-
np.testing.assert_almost_equal(exp_probs, probs)
278-
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
204+
assert_predictions_equal(test_data, model, exp_probs)
279205

280206
# Test the slower algorithm for non-Table data (iteration in Python)
281207
test_data = NotATable.from_numpy(domain, test_x, test_y)
282-
probs = model(test_data, ret=model.Probs)
283-
np.testing.assert_almost_equal(exp_probs, probs)
284-
values = model(test_data)
285-
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
286-
values, probs = model(test_data, ret=model.ValueProbs)
287-
np.testing.assert_almost_equal(exp_probs, probs)
288-
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
208+
assert_predictions_equal(test_data, model, exp_probs)
289209

290210
# Test prediction directly on numpy
291-
probs = model(test_x, ret=model.Probs)
292-
np.testing.assert_almost_equal(exp_probs, probs)
293-
values = model(test_x)
294-
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
295-
values, probs = model(test_x, ret=model.ValueProbs)
296-
np.testing.assert_almost_equal(exp_probs, probs)
297-
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
211+
assert_predictions_equal(test_x, model, exp_probs)
298212

299213
# Test prediction on instances
300214
for inst, exp_prob in zip(test_data, exp_probs):
301-
np.testing.assert_almost_equal(
302-
model(inst, ret=model.Probs),
303-
exp_prob)
304-
self.assertEqual(model(inst), np.argmax(exp_prob))
305-
value, prob = model(inst, ret=model.ValueProbs)
306-
np.testing.assert_almost_equal(prob, exp_prob)
307-
self.assertEqual(value, np.argmax(exp_prob))
215+
assert_predictions_equal(inst, model, exp_prob)
308216

309217
# Test prediction by directly calling predict. This is needed to test
310218
# csc_matrix, but doesn't hurt others
311219
if sparse is sp.csc_matrix:
312220
test_x = sparse(test_x)
313221
values, probs = model.predict(test_x)
314-
np.testing.assert_almost_equal(exp_probs, probs)
222+
np.testing.assert_almost_equal(probs, exp_probs)
315223
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
316224

317-
def _test_predict_missing_attributes(self, sparse):
225+
@staticmethod
226+
def _create_missing_attributes(sparse):
318227
x = np.array([
319228
[1, 0, 0],
320229
[0, 1, 0],
@@ -325,24 +234,29 @@ def _test_predict_missing_attributes(self, sparse):
325234
[1, 2, np.nan]])
326235
if sparse is not None:
327236
x = sparse(x)
328-
y = np.array([1,0,0,0,1,1,1])
237+
y = np.array([1, 0, 0, 0, 1, 1, 1])
238+
239+
test_x = np.array([[np.nan, np.nan, np.nan],
240+
[np.nan, 0, np.nan],
241+
[0, np.nan, np.nan]])
242+
if sparse is not None and sparse is not sp.csc_matrix:
243+
test_x = sparse(test_x)
244+
exp_probs = np.array([[(3 + 1) / (7 + 2), (4 + 1) / (7 + 2)],
245+
[(1 + 1) / (2 + 2), (1 + 1) / (2 + 2)],
246+
[(3 + 1) / (3 + 2), (0 + 1) / (3 + 2)]])
247+
329248
domain = Domain(
330249
[DiscreteVariable("a", values="ab"),
331250
DiscreteVariable("b", values="abc"),
332251
DiscreteVariable("c", values="a")],
333252
DiscreteVariable("y", values="AB"))
334-
data = Table.from_numpy(domain, x, y)
253+
return Table.from_numpy(domain, x, y), test_x, exp_probs
335254

255+
def _test_predict_missing_attributes(self, sparse):
256+
data, test_x, exp_probs = self._create_missing_attributes(sparse)
336257
model = self.learner(data)
337-
test_x = np.array([[np.nan, np.nan, np.nan],
338-
[np.nan, 0, np.nan],
339-
[0, np.nan, np.nan]])
340-
if sparse is not None and sparse is not sp.csc_matrix:
341-
test_x = sparse(test_x)
342258
probs = model(test_x, ret=model.Probs)
343-
np.testing.assert_almost_equal(probs, [[(3+1)/(7+2), (4+1)/(7+2)],
344-
[(1+1)/(2+2), (1+1)/(2+2)],
345-
[(3+1)/(3+2), (0+1)/(3+2)]])
259+
np.testing.assert_almost_equal(probs, exp_probs)
346260

347261
def test_no_attributes(self):
348262
y = np.array([0, 0, 0, 1, 1, 1, 2, 2])

0 commit comments

Comments
 (0)