Skip to content

Commit c9e60c9

Browse files
authored
Merge pull request #3540 from janezd/reimplement-nb-predict-storage
[ENH] Naive Bayes: Implement predict, fix predict_storage
2 parents c2735dd + 9fb8bd3 commit c9e60c9

File tree

2 files changed

+249
-25
lines changed

2 files changed

+249
-25
lines changed

Orange/classification/naive_bayes.py

Lines changed: 69 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numpy as np
2+
import scipy.sparse as sp
23

34
from Orange.classification import Learner, Model
4-
from Orange.data import Instance, Storage
5+
from Orange.data import Instance, Storage, Table
56
from Orange.statistics import contingency
67
from Orange.preprocess import Discretize, RemoveNaNColumns
78

@@ -48,22 +49,80 @@ def __init__(self, log_cont_prob, class_prob, domain):
4849

4950
def predict_storage(self, data):
5051
if isinstance(data, Instance):
51-
data = [data]
52-
if len(data.domain.attributes) == 0:
52+
data = Table(np.atleast_2d(data.x))
53+
if type(data) is Table: # pylint: disable=unidiomatic-typecheck
54+
return self.predict(data.X)
55+
56+
if not len(data) or not len(data[0]):
5357
probs = np.tile(self.class_prob, (len(data), 1))
5458
else:
5559
isnan = np.isnan
56-
probs = np.exp(
60+
zeros = np.zeros_like(self.class_prob)
61+
probs = np.atleast_2d(np.exp(
5762
np.log(self.class_prob) +
58-
np.array([np.zeros_like(self.class_prob)
59-
if isnan(ins.x).all() else
60-
np.sum(attr_prob[:, int(attr_val)]
61-
for attr_val, attr_prob in zip(ins, self.log_cont_prob)
62-
if not isnan(attr_val))
63-
for ins in data]))
63+
np.array([
64+
zeros if isnan(ins.x).all() else
65+
sum(attr_prob[:, int(attr_val)]
66+
for attr_val, attr_prob in zip(ins, self.log_cont_prob)
67+
if not isnan(attr_val))
68+
for ins in data])))
6469
probs /= probs.sum(axis=1)[:, None]
6570
values = probs.argmax(axis=1)
6671
return values, probs
6772

73+
def predict(self, X):
    """Predict classes and class probabilities for a matrix of instances.

    Args:
        X: dense array or scipy sparse matrix with one row per instance.

    Returns:
        A tuple ``(values, probs)``. ``probs`` holds, for each row, class
        probabilities normalized to sum to 1; ``values`` holds the index
        of the most probable class in each row.
    """
    if not self.log_cont_prob:
        # No attribute tables: predictions are just the class priors.
        log_probs = self._priors(X)
    elif sp.issparse(X):
        log_probs = self._sparse_probs(X)
    else:
        log_probs = self._dense_probs(X)

    # Shift every row so that its largest log-score is 0 before calling
    # exp. The shift cancels out in the normalization below, but without
    # it a long sum of log-probabilities underflows to 0 for every class
    # and the normalization yields 0 / 0 = NaN.
    shift = log_probs.max(axis=1, keepdims=True)
    # Rows whose maximum is not finite are left unshifted so that they
    # behave exactly as they did without the shift.
    shift[~np.isfinite(shift)] = 0
    probs = np.exp(log_probs - shift)
    probs /= probs.sum(axis=1)[:, None]
    values = probs.argmax(axis=1)
    return values, probs
84+
85+
def _priors(self, data):
    """Return the log class priors repeated once for every row of `data`."""
    n_rows = data.shape[0]
    log_priors = np.log(self.class_prob)
    return np.tile(log_priors, (n_rows, 1))
87+
88+
def _dense_probs(self, data):
    """Compute unnormalized log-probabilities of classes for dense data.

    Args:
        data: 2-D array with one row per instance and one column per
            attribute; entries are used as value indices into the
            corresponding table in ``self.log_cont_prob``, NaN marks a
            missing value.

    Returns:
        Array with one row per instance and one column per class: the
        log priors plus the conditional log-probabilities of all
        non-missing attribute values.
    """
    probs = self._priors(data)
    zeros = np.zeros((1, probs.shape[1]))
    for col, attr_prob in zip(data.T, self.log_cont_prob):
        # Rows of `lookup` are indexed by attribute value; one extra
        # all-zero row is appended for missing values, so they add
        # nothing to the score.
        lookup = np.vstack((attr_prob.T, zeros))
        # The zero row sits right after the last real value, i.e. at
        # index `attr_prob.shape[1]`. (Using `shape[1] - 1` would score
        # a missing value as if it were the last category.)
        indices = np.where(np.isnan(col), attr_prob.shape[1], col)
        probs += lookup[indices.astype(int)]
    return probs
98+
99+
def _sparse_probs(self, data):
    """Compute unnormalized log-probabilities of classes for sparse data.

    Every row starts with the log priors plus the log-probabilities of
    value 0 for all attributes (the implicit value of entries that are
    not stored). Stored entries then add a correction relative to that
    baseline.

    Args:
        data: scipy sparse matrix with one row per instance and one
            column per attribute; stored entries are used as value
            indices, NaN marks a missing value.

    Returns:
        Array with one row per instance and one column per class.
    """
    probs = self._priors(data)

    n_classes = self.log_cont_prob[0].shape[0]
    # One extra slot past the largest number of values is reserved as
    # the index that represents missing values.
    n_vals = max(p.shape[1] for p in self.log_cont_prob) + 1
    missing = n_vals - 1
    log_prob = np.zeros((len(self.log_cont_prob), n_vals, n_classes))
    for i, p in enumerate(self.log_cont_prob):
        baseline = p.T[0].copy()
        probs += baseline
        log_prob[i, :p.shape[1]] = p.T - baseline
        # A missing value must contribute nothing at all, so its
        # correction has to cancel the baseline added above; leaving
        # it at zero would score missing values as value 0.
        log_prob[i, missing] = -baseline

    # Convert before extracting values: `data.data` of another format
    # is ordered differently from the `indptr`/`indices` of the
    # converted matrix.
    if not sp.isspmatrix_csr(data):
        data = data.tocsc()
    dat = data.data.copy()
    dat[np.isnan(dat)] = missing
    dat = dat.astype(int)

    if sp.isspmatrix_csr(data):
        # Row-wise: sum corrections of all entries stored in the row.
        for row, start, end in zip(probs, data.indptr, data.indptr[1:]):
            row += log_prob[data.indices[start:end],
                            dat[start:end]].sum(axis=0)
    else:
        # Column-wise: each column is one attribute with its own table.
        for start, end, attr_prob in zip(data.indptr, data.indptr[1:],
                                         log_prob):
            probs[data.indices[start:end]] += attr_prob[dat[start:end]]

    return probs
126+
68127

69128
NaiveBayesLearner.__returns__ = NaiveBayesModel

Orange/tests/test_naive_bayes.py

Lines changed: 180 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,42 @@
22
# pylint: disable=missing-docstring
33

44
import unittest
5+
import warnings
6+
from unittest.mock import Mock
7+
8+
import numpy as np
9+
import scipy.sparse as sp
510

611
from Orange.classification import NaiveBayesLearner
712
from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable
813
from Orange.evaluation import CrossValidation, CA
914

1015

16+
# This class is used to force predict_storage to fall back to the slower
17+
# procedure instead of calling `predict`
18+
class NotATable(Table): # pylint: disable=too-many-ancestors,abstract-method
19+
pass
20+
21+
1122
class TestNaiveBayesLearner(unittest.TestCase):
1223
@classmethod
1324
def setUpClass(cls):
14-
data = Table('titanic')
25+
cls.data = data = Table('titanic')
1526
cls.learner = NaiveBayesLearner()
16-
cls.model = cls.learner(data)
1727
cls.table = data[::20]
1828

29+
def setUp(self):
30+
self.model = self.learner(self.data)
31+
1932
def test_NaiveBayes(self):
2033
results = CrossValidation(self.table, [self.learner], k=10)
2134
ca = CA(results)
2235
self.assertGreater(ca, 0.7)
2336
self.assertLess(ca, 0.9)
2437

25-
def test_predict_single_instance(self):
26-
for ins in self.table:
27-
self.model(ins)
28-
val, prob = self.model(ins, self.model.ValueProbs)
29-
30-
def test_predict_table(self):
31-
self.model(self.table)
32-
vals, probs = self.model(self.table, self.model.ValueProbs)
33-
34-
def test_predict_numpy(self):
35-
X = self.table.X[::20]
36-
self.model(X)
37-
vals, probs = self.model(X, self.model.ValueProbs)
38+
results = CrossValidation(Table("iris"), [self.learner], k=10)
39+
ca = CA(results)
40+
self.assertGreater(ca, 0.7)
3841

3942
def test_degenerate(self):
4043
d = Domain((ContinuousVariable(name="A"),
@@ -53,3 +56,165 @@ def test_allnan_cv(self):
5356
data = Table('voting')
5457
results = CrossValidation(data, [self.learner])
5558
self.assertFalse(any(results.failed))
59+
60+
def test_prediction_routing(self):
61+
data = self.data
62+
predict = self.model.predict = Mock(return_value=(data.Y, None))
63+
64+
self.model(data)
65+
predict.assert_called()
66+
predict.reset_mock()
67+
68+
self.model(data.X)
69+
predict.assert_called()
70+
predict.reset_mock()
71+
72+
self.model.predict_storage(data)
73+
predict.assert_called()
74+
predict.reset_mock()
75+
76+
self.model.predict_storage(data[0])
77+
predict.assert_called()
78+
79+
def test_compare_results_of_predict_and_predict_storage(self):
80+
data2 = NotATable("titanic")
81+
82+
self.model = self.learner(self.data[:50])
83+
predict = self.model.predict = Mock(side_effect=self.model.predict)
84+
values, probs = self.model.predict_storage(self.data[50:])
85+
predict.assert_called()
86+
predict.reset_mock()
87+
values2, probs2 = self.model.predict_storage(data2[50:])
88+
predict.assert_not_called()
89+
90+
np.testing.assert_equal(values, values2)
91+
np.testing.assert_equal(probs, probs2)
92+
93+
def test_predictions(self):
94+
self._test_predictions(sparse=None)
95+
96+
def test_predictions_csr_matrix(self):
97+
with warnings.catch_warnings():
98+
warnings.filterwarnings(
99+
"ignore", ".*the matrix subclass.*", PendingDeprecationWarning)
100+
self._test_predictions(sparse=sp.csr_matrix)
101+
102+
def test_predictions_csc_matrix(self):
103+
with warnings.catch_warnings():
104+
warnings.filterwarnings(
105+
"ignore", ".*the matrix subclass.*", PendingDeprecationWarning)
106+
self._test_predictions(sparse=sp.csc_matrix)
107+
108+
def _test_predictions(self, sparse):
109+
x = np.array([
110+
[1, 0, 0],
111+
[0, np.nan, 0],
112+
[0, 1, 0],
113+
[0, 0, 0],
114+
[1, 2, 0],
115+
[1, 1, 0],
116+
[1, 2, 0],
117+
[0, 1, 0]])
118+
if sparse is not None:
119+
x = sparse(x)
120+
121+
y = np.array([0, 0, 0, 1, 1, 1, 2, 2])
122+
domain = Domain(
123+
[DiscreteVariable("a", values="ab"),
124+
DiscreteVariable("b", values="abc"),
125+
DiscreteVariable("c", values="a")],
126+
DiscreteVariable("y", values="abc"))
127+
data = Table.from_numpy(domain, x, y)
128+
129+
model = self.learner(data)
130+
np.testing.assert_almost_equal(
131+
model.class_prob,
132+
[4/11, 4/11, 3/11]
133+
)
134+
np.testing.assert_almost_equal(
135+
np.exp(model.log_cont_prob[0]) * model.class_prob[:, None],
136+
[[3/7, 2/7], [2/7, 3/7], [2/7, 2/7]])
137+
np.testing.assert_almost_equal(
138+
np.exp(model.log_cont_prob[1]) * model.class_prob[:, None],
139+
[[2/5, 1/3, 1/5], [2/5, 1/3, 2/5], [1/5, 1/3, 2/5]])
140+
np.testing.assert_almost_equal(
141+
np.exp(model.log_cont_prob[2]) * model.class_prob[:, None],
142+
[[4/11], [4/11], [3/11]])
143+
144+
test_x = np.array([[a, b, 0] for a in [0, 1] for b in [0, 1, 2]])
145+
# Classifiers reject csc matrices in the base class
146+
# Naive Bayes classifier supports them if predict is
147+
# called directly, which we do below
148+
if sparse is not None and sparse is not sp.csc_matrix:
149+
test_x = sparse(test_x)
150+
test_y = np.full((6, ), np.nan)
151+
# The following was computed manually, too
152+
exp_probs = np.array([
153+
[0.47368421052632, 0.31578947368421, 0.21052631578947],
154+
[0.39130434782609, 0.26086956521739, 0.34782608695652],
155+
[0.24324324324324, 0.32432432432432, 0.43243243243243],
156+
[0.31578947368421, 0.47368421052632, 0.21052631578947],
157+
[0.26086956521739, 0.39130434782609, 0.34782608695652],
158+
[0.15000000000000, 0.45000000000000, 0.40000000000000]
159+
])
160+
161+
# Test the faster algorithm for Table (numpy matrices)
162+
test_data = Table.from_numpy(domain, test_x, test_y)
163+
probs = model(test_data, ret=model.Probs)
164+
np.testing.assert_almost_equal(exp_probs, probs)
165+
values = model(test_data)
166+
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
167+
values, probs = model(test_data, ret=model.ValueProbs)
168+
np.testing.assert_almost_equal(exp_probs, probs)
169+
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
170+
171+
# Test the slower algorithm for non-Table data (iteration in Python)
172+
test_data = NotATable.from_numpy(domain, test_x, test_y)
173+
probs = model(test_data, ret=model.Probs)
174+
np.testing.assert_almost_equal(exp_probs, probs)
175+
values = model(test_data)
176+
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
177+
values, probs = model(test_data, ret=model.ValueProbs)
178+
np.testing.assert_almost_equal(exp_probs, probs)
179+
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
180+
181+
# Test prediction directly on numpy
182+
probs = model(test_x, ret=model.Probs)
183+
np.testing.assert_almost_equal(exp_probs, probs)
184+
values = model(test_x)
185+
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
186+
values, probs = model(test_x, ret=model.ValueProbs)
187+
np.testing.assert_almost_equal(exp_probs, probs)
188+
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
189+
190+
# Test prediction on instances
191+
for inst, exp_prob in zip(test_data, exp_probs):
192+
np.testing.assert_almost_equal(
193+
model(inst, ret=model.Probs)[0],
194+
exp_prob)
195+
self.assertEqual(model(inst), np.argmax(exp_prob))
196+
value, prob = model(inst, ret=model.ValueProbs)
197+
np.testing.assert_almost_equal(prob[0], exp_prob)
198+
self.assertEqual(value, np.argmax(exp_prob))
199+
200+
# Test prediction by directly calling predict. This is needed to test
201+
# csc_matrix, but doesn't hurt others
202+
if sparse is sp.csc_matrix:
203+
test_x = sparse(test_x)
204+
values, probs = model.predict(test_x)
205+
np.testing.assert_almost_equal(exp_probs, probs)
206+
np.testing.assert_equal(values, np.argmax(exp_probs, axis=1))
207+
208+
def test_no_attributes(self):
209+
y = np.array([0, 0, 0, 1, 1, 1, 2, 2])
210+
domain = Domain([], DiscreteVariable("y", values="abc"))
211+
data = Table.from_numpy(domain, np.zeros((len(y), 0)), y.T)
212+
model = self.learner(data)
213+
np.testing.assert_almost_equal(
214+
model.predict_storage(np.zeros((5, 0)))[1],
215+
[[4/11, 4/11, 3/11]] * 5
216+
)
217+
218+
219+
if __name__ == "__main__":
220+
unittest.main()

0 commit comments

Comments
 (0)