Skip to content

Commit d4e89d0

Browse files
committed
[FIX] NaiveBayes: Handle degenerate cases
1 parent 9bf3fa0 commit d4e89d0

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

Orange/classification/naive_bayes.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,17 @@ def predict_storage(self, data):
4747
data = [data]
4848
n_cls = len(self.class_freq)
4949
class_prob = (self.class_freq + 1) / (np.sum(self.class_freq) + n_cls)
50-
log_cont_prob = [np.log(np.divide(np.array(c) + 1,
51-
self.class_freq.reshape((n_cls, 1)) +
52-
c.shape[1])) for c in self.cont]
53-
probs = np.exp(np.array([np.sum(attr_prob[:, int(attr_val)]
54-
for attr_val, attr_prob
55-
in zip(ins, log_cont_prob)
56-
if not np.isnan(attr_val))
57-
for ins in data]) + np.log(class_prob))
50+
if len(data.domain.attributes) == 0:
51+
probs = np.tile(class_prob, (len(data), 1))
52+
else:
53+
log_cont_prob = [np.log(np.divide(np.array(c) + 1,
54+
self.class_freq.reshape((n_cls, 1)) +
55+
c.shape[1])) for c in self.cont]
56+
probs = np.exp(np.array([np.sum(attr_prob[:, int(attr_val)]
57+
for attr_val, attr_prob
58+
in zip(ins, log_cont_prob)
59+
if not np.isnan(attr_val))
60+
for ins in data]) + np.log(class_prob))
5861
probs /= probs.sum(axis=1)[:, None]
5962
values = probs.argmax(axis=1)
6063
return values, probs

Orange/tests/test_naive_bayes.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import unittest
55

66
from Orange.classification import NaiveBayesLearner
7-
from Orange.data import Table
7+
from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable
88
from Orange.evaluation import CrossValidation, CA
99

1010

@@ -35,3 +35,13 @@ def test_predict_numpy(self):
3535
X = self.table.X[::20]
3636
self.model(X)
3737
vals, probs = self.model(X, self.model.ValueProbs)
38+
39+
def test_degenerate(self):
40+
d = Domain((ContinuousVariable(name="A"), ContinuousVariable(name="B"), ContinuousVariable(name="C")),
41+
DiscreteVariable(name="CLASS", values=["M", "F"]))
42+
t = Table(d, [[0,1,0,0], [0,1,0,1], [0,1,0,1]])
43+
nb = NaiveBayesLearner()
44+
model = nb(t)
45+
self.assertEqual(model.domain.attributes, ())
46+
self.assertEqual(model(t[0]), 1)
47+
self.assertTrue(all(model(t[0]) == 1))

0 commit comments

Comments
 (0)