Skip to content

Commit 6116c1b

Browse files
committed
Naive Bayes: Reimplement predict_storage to:
- avoid the numpy warning about passing a generator to sum
- loop over columns instead of over rows (usually better, and more friendly to pandas sometime in the future)
- natively support sparse matrices
1 parent c96c66f commit 6116c1b

File tree

2 files changed

+70
-13
lines changed

2 files changed

+70
-13
lines changed

Orange/classification/naive_bayes.py

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import scipy.sparse as sp
23

34
from Orange.classification import Learner, Model
45
from Orange.data import Instance, Storage
@@ -48,22 +49,62 @@ def __init__(self, log_cont_prob, class_prob, domain):
4849

4950
def predict_storage(self, data):
    """Predict the most probable class and class probabilities for ``data``.

    ``data`` may be an ``Instance`` (its ``x`` vector is promoted to a
    one-row 2-D array), a ``Storage`` (its ``X`` matrix is used), or
    anything else, which is used as-is and is expected to be a dense
    array or a scipy sparse matrix with one row per example.

    Returns a tuple ``(values, probs)`` where ``probs`` holds the
    row-normalised class probabilities and ``values`` the index of the
    largest probability in every row.
    """
    if isinstance(data, Instance):
        data = np.atleast_2d(data.x)
    elif isinstance(data, Storage):
        data = data.X

    if not self.log_cont_prob:
        # No attribute tables: only the class priors are available.
        log_probs = self._priors(data)
    elif sp.issparse(data):
        log_probs = self._sparse_probs(data)
    else:
        log_probs = self._dense_probs(data)

    # Shift every row so that its largest log-probability is 0 before
    # exponentiating.  The shift cancels out in the normalisation below,
    # but without it a long sum of log-probabilities underflows to 0 for
    # every class and the division yields NaN.
    log_probs = log_probs - log_probs.max(axis=1)[:, None]
    probs = np.exp(log_probs)
    probs /= probs.sum(axis=1)[:, None]
    values = probs.argmax(axis=1)
    return values, probs
6766

67+
def _priors(self, data):
    """Return the log class priors repeated once for every row of ``data``.

    The result has shape ``(data.shape[0], len(self.class_prob))``.
    """
    n_rows = data.shape[0]
    log_prior = np.log(self.class_prob)
    return np.tile(log_prior, (n_rows, 1))
69+
70+
def _dense_probs(self, data):
    """Return unnormalised log class probabilities for a dense matrix.

    ``data`` is a 2-D array with one row per example and one column per
    entry of ``self.log_cont_prob``.  Every entry of ``log_cont_prob`` is
    indexed as ``[class, value]``.  Starting from the log priors, the
    log-probability of each observed value is added column by column.
    Missing values (NaN) contribute nothing.

    The result has shape ``(n_rows, n_classes)``.
    """
    probs = self._priors(data)
    zeros = np.zeros((1, probs.shape[1]))
    for col, attr_prob in zip(data.T, self.log_cont_prob):
        # Extend the per-value table with one all-zero row and send NaNs
        # to it.  The zero row lives at index ``attr_prob.shape[1]``;
        # pointing NaNs at ``shape[1] - 1`` would wrongly score them as
        # the attribute's last real value.
        lookup = np.vstack((attr_prob.T, zeros))
        idx = np.where(np.isnan(col), attr_prob.shape[1], col).astype(int)
        probs += lookup[idx]
    return probs
80+
81+
def _sparse_probs(self, data):
    """Return unnormalised log class probabilities for a sparse matrix.

    Entries that are not stored in ``data`` are treated as value 0.  The
    value-0 log-probability of every attribute is therefore added to all
    rows up front, and stored entries only add the difference between
    their value and value 0.  Explicitly stored NaNs are treated as
    missing and contribute nothing.

    CSR input is processed row by row; any other sparse format is
    converted to CSC and processed column by column.  The result has
    shape ``(n_rows, n_classes)``.
    """
    probs = self._priors(data)

    # One extra slot per attribute (index ``n_vals - 1``) is reserved as
    # the sentinel for missing values.
    n_vals = max(p.shape[1] for p in self.log_cont_prob) + 1
    log_prob = np.zeros((len(self.log_cont_prob),
                         n_vals,
                         self.log_cont_prob[0].shape[0]))
    for i, p in enumerate(self.log_cont_prob):
        p0 = p.T[0].copy()
        probs[:] += p0
        log_prob[i, :p.shape[1]] = p.T - p0
        # A missing value must cancel the baseline added above; a zero
        # row here would score NaN as if it were value 0.
        log_prob[i, n_vals - 1] = -p0

    by_rows = sp.isspmatrix_csr(data)
    # Convert before reading ``.data`` so that the values stay aligned
    # with the ``indices``/``indptr`` arrays used below.
    mat = data if by_rows else data.tocsc()

    dat = mat.data.copy()
    dat[np.isnan(dat)] = n_vals - 1
    dat = dat.astype(int)

    if by_rows:
        for row, start, end in zip(probs, mat.indptr, mat.indptr[1:]):
            row += log_prob[mat.indices[start:end],
                            dat[start:end]].sum(axis=0)
    else:
        for start, end, attr_prob in zip(mat.indptr, mat.indptr[1:],
                                         log_prob):
            probs[mat.indices[start:end]] += attr_prob[dat[start:end]]

    return probs
108+
68109

69110
# Set the learner's ``__returns__`` attribute to the model class defined above
# (presumably how the learner advertises the type it produces -- confirm
# against ``Learner``).
NaiveBayesLearner.__returns__ = NaiveBayesModel

Orange/tests/test_naive_bayes.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
import unittest
55

6+
import numpy as np
7+
import scipy.sparse as sp
8+
69
from Orange.classification import NaiveBayesLearner
710
from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable
811
from Orange.evaluation import CrossValidation, CA
@@ -11,7 +14,7 @@
1114
class TestNaiveBayesLearner(unittest.TestCase):
1215
@classmethod
1316
def setUpClass(cls):
14-
data = Table('titanic')
17+
cls.data = data = Table('titanic')
1518
cls.learner = NaiveBayesLearner()
1619
cls.model = cls.learner(data)
1720
cls.table = data[::20]
@@ -22,6 +25,10 @@ def test_NaiveBayes(self):
2225
self.assertGreater(ca, 0.7)
2326
self.assertLess(ca, 0.9)
2427

28+
results = CrossValidation(Table("iris"), [self.learner], k=10)
29+
ca = CA(results)
30+
self.assertGreater(ca, 0.7)
31+
2532
def test_predict_single_instance(self):
2633
for ins in self.table:
2734
self.model(ins)
@@ -53,3 +60,12 @@ def test_allnan_cv(self):
5360
data = Table('voting')
5461
results = CrossValidation(data, [self.learner])
5562
self.assertFalse(any(results.failed))
63+
64+
def test_sparse(self):
    """Sparse input must give the same probabilities as dense input.

    The dense prediction on the full data matrix is the reference; the
    CSC and CSR versions of the same matrix are compared against it.
    """
    dense_x = self.data.X
    _, expected = self.model.predict_storage(dense_x)

    for to_sparse in (sp.csc_matrix, sp.csr_matrix):
        _, actual = self.model.predict_storage(to_sparse(dense_x))
        np.testing.assert_almost_equal(expected, actual)

0 commit comments

Comments
 (0)