Skip to content

Commit 0dddfc8

Browse files
committed
Naive Bayes: Add method predict, fix predict_storage
1 parent 0637c96 commit 0dddfc8

File tree

2 files changed

+110
-9
lines changed

2 files changed

+110
-9
lines changed

Orange/classification/naive_bayes.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import numpy as np
2+
import scipy.sparse as sp
23

34
from Orange.classification import Learner, Model
4-
from Orange.data import Instance, Storage
5+
from Orange.data import Instance, Storage, Table
56
from Orange.statistics import contingency
67
from Orange.preprocess import Discretize, RemoveNaNColumns
78

@@ -47,23 +48,81 @@ def __init__(self, log_cont_prob, class_prob, domain):
4748
self.class_prob = class_prob
4849

4950
def predict_storage(self, data):
    """Predict classes and class probabilities for a table or an instance.

    A plain ``Table`` is routed to :meth:`predict`, which works directly
    on its ``X`` matrix. A single ``Instance`` is wrapped into a one-row
    ``Table``; it and any other storage are scored row by row below.

    Returns a tuple ``(values, probs)``: ``probs`` holds one row of
    normalized class probabilities per data row and ``values`` is the
    index of the most probable class in each row.
    """
    # Exact type check on purpose: only plain Tables take the fast path.
    if type(data) is Table:  # pylint: disable=unidiomatic-typecheck
        return self.predict(data.X)

    if isinstance(data, Instance):
        data = Table(data.domain, [data])
    if len(data.domain.attributes) == 0:
        # Nothing to condition on: every row gets the class priors.
        probs = np.tile(self.class_prob, (len(data), 1))
    else:
        isnan = np.isnan
        no_evidence = np.zeros_like(self.class_prob)
        log_likelihoods = np.array([
            no_evidence if isnan(ins.x).all() else
            # axis=0 keeps one summed log-probability per class; without
            # it np.sum would collapse the list of per-attribute vectors
            # into a single scalar shared by all classes.
            np.sum([attr_prob[:, int(attr_val)]
                    for attr_val, attr_prob in zip(ins, self.log_cont_prob)
                    if not isnan(attr_val)],
                   axis=0)
            for ins in data])
        probs = np.atleast_2d(
            np.exp(np.log(self.class_prob) + log_likelihoods))
    probs /= probs.sum(axis=1)[:, None]
    values = probs.argmax(axis=1)
    return values, probs
6772

73+
def predict(self, X):
    """Predict classes and class probabilities for a feature matrix.

    ``X`` may be a dense array or a scipy sparse matrix with one row per
    instance. When the model holds no per-attribute log-probabilities,
    only the class priors are used.

    Returns a tuple ``(values, probs)``: ``probs`` holds one row of
    normalized class probabilities per row of ``X`` and ``values`` is the
    index of the most probable class in each row.
    """
    if not self.log_cont_prob:
        log_probs = self._priors(X)
    elif sp.issparse(X):
        log_probs = self._sparse_probs(X)
    else:
        log_probs = self._dense_probs(X)
    # Shift every row so that its largest log-probability is 0 before
    # exponentiating. Normalization cancels the shift, but without it a
    # row whose log-probabilities are all very negative (many attributes)
    # underflows to all zeros and normalizes to NaN.
    log_probs = log_probs - log_probs.max(axis=1)[:, None]
    probs = np.exp(log_probs)
    probs /= probs.sum(axis=1)[:, None]
    values = probs.argmax(axis=1)
    return values, probs
84+
85+
def _priors(self, data):
86+
return np.tile(np.log(self.class_prob), (data.shape[0], 1))
87+
88+
def _dense_probs(self, data):
89+
probs = self._priors(data)
90+
zeros = np.zeros((1, probs.shape[1]))
91+
for col, attr_prob in zip(data.T, self.log_cont_prob):
92+
col = col.copy()
93+
col[np.isnan(col)] = attr_prob.shape[1] - 1
94+
col = col.astype(int)
95+
probs0 = np.vstack((attr_prob.T, zeros))
96+
probs += probs0[col]
97+
return probs
98+
99+
def _sparse_probs(self, data):
100+
probs = self._priors(data)
101+
102+
n_vals = max(p.shape[1] for p in self.log_cont_prob) + 1
103+
log_prob = np.zeros((len(self.log_cont_prob),
104+
n_vals,
105+
self.log_cont_prob[0].shape[0]))
106+
for i, p in enumerate(self.log_cont_prob):
107+
p0 = p.T[0].copy()
108+
probs[:] += p0
109+
log_prob[i, :p.shape[1]] = p.T - p0
110+
111+
dat = data.data.copy()
112+
dat[np.isnan(dat)] = n_vals - 1
113+
dat = dat.astype(int)
114+
115+
if sp.isspmatrix_csr(data):
116+
for row, start, end in zip(probs, data.indptr, data.indptr[1:]):
117+
row += log_prob[data.indices[start:end],
118+
dat[start:end]].sum(axis=0)
119+
else:
120+
csc = data.tocsc()
121+
for start, end, attr_prob in zip(csc.indptr, csc.indptr[1:],
122+
log_prob):
123+
probs[csc.indices[start:end]] += attr_prob[dat[start:end]]
124+
125+
return probs
126+
68127

69128
NaiveBayesLearner.__returns__ = NaiveBayesModel

Orange/tests/test_naive_bayes.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
# pylint: disable=missing-docstring
33

44
import unittest
5+
from unittest.mock import Mock
6+
7+
import numpy as np
8+
import scipy.sparse as sp
59

610
from Orange.classification import NaiveBayesLearner
711
from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable
@@ -11,17 +15,23 @@
1115
class TestNaiveBayesLearner(unittest.TestCase):
1216
@classmethod
def setUpClass(cls):
    """Load the titanic data once and prepare the shared learner."""
    titanic = Table('titanic')
    cls.data = titanic
    cls.learner = NaiveBayesLearner()
    # Every 20th row serves as a smaller table for the slower tests.
    cls.table = titanic[::20]
1821

22+
def setUp(self):
    """Fit a fresh model before every test.

    A per-test model keeps tests that modify it (for example by
    replacing its ``predict`` with a mock) from affecting the others.
    """
    learner = self.learner
    self.model = learner(self.data)
24+
1925
def test_NaiveBayes(self):
    """Cross-validated accuracy stays within the expected bounds."""
    titanic_ca = CA(CrossValidation(self.table, [self.learner], k=10))
    self.assertGreater(titanic_ca, 0.7)
    self.assertLess(titanic_ca, 0.9)

    iris_ca = CA(CrossValidation(Table("iris"), [self.learner], k=10))
    self.assertGreater(iris_ca, 0.7)
34+
2535
def test_predict_single_instance(self):
2636
for ins in self.table:
2737
self.model(ins)
@@ -53,3 +63,35 @@ def test_allnan_cv(self):
5363
data = Table('voting')
5464
results = CrossValidation(data, [self.learner])
5565
self.assertFalse(any(results.failed))
66+
67+
def test_sparse(self):
    """CSC and CSR inputs give the same probabilities as dense input."""
    dense_X = self.data.X
    _, dense_p = self.model.predict(dense_X)

    for to_sparse in (sp.csc_matrix, sp.csr_matrix):
        _, sparse_p = self.model.predict(to_sparse(dense_X))
        np.testing.assert_almost_equal(dense_p, sparse_p)
75+
76+
def test_prediction_routing(self):
    """Tables and arrays reach ``predict``; a single instance does not."""
    data = self.data
    predict = Mock(return_value=(data.Y, None))
    self.model.predict = predict

    routed_calls = (
        (self.model, data),
        (self.model, data.X),
        (self.model.predict_storage, data),
    )
    for call, argument in routed_calls:
        call(argument)
        predict.assert_called()
        predict.reset_mock()

    self.model.predict_storage(data[0])
    predict.assert_not_called()
94+
95+
96+
if __name__ == "__main__":
97+
unittest.main()

0 commit comments

Comments
 (0)