Skip to content

Commit cbe2e11

Browse files
committed
fix tree's find_threshold_entropy for repeated values
1 parent 1852c10 commit cbe2e11

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

Orange/classification/_tree_scorers.pyx

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ def find_threshold_entropy(const double[:] x, const double[:] y,
7979
curr_y = <int>y[idx[i]]
8080
distr[curr_y] -= 1
8181
distr[n_classes + curr_y] += 1
82-
if curr_y != y[idx[i + 1]] and x[idx[i]] != x[idx[i + 1]]:
82+
if x[idx[i]] != x[idx[i + 1]]:
8383
entro = (i + 1) * log(i + 1) + (N - i - 1) * log(N - i - 1)
8484
for j in range(2 * n_classes):
8585
if distr[j]:

Orange/tests/test_orangetree.py

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
import scipy.sparse as sp
8+
from Orange.classification._tree_scorers import find_threshold_entropy
89

910
from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable
1011
from Orange.classification.tree import \
@@ -448,3 +449,36 @@ def test_compile_and_run_cont_sparse(self):
448449
[14, 2, 1]], dtype=float
449450
))
450451
np.testing.assert_equal(model.get_values(x), expected_values)
452+
453+
454+
class TestScorers(unittest.TestCase):
    """Direct tests of the compiled ``find_threshold_entropy`` scorer.

    Every test calls ``find_threshold_entropy(x, y, ind, 2, 1)`` and unpacks
    the result as ``(e, t)``: ``e`` is compared approximately against the
    expected score and ``t`` exactly against the expected threshold.
    NOTE(review): the trailing ``2, 1`` arguments are presumably the number
    of classes and the minimal leaf size -- confirm against the scorer's
    signature.
    """

    def test_find_threshold_entropy(self):
        """Four distinct values whose classes change between 2 and 3."""
        x = np.array([1, 2, 3, 4], dtype=float)
        y = np.array([0, 0, 1, 1], dtype=float)
        # Stable sort keeps equal values in their original order.
        ind = np.argsort(x, kind="stable")
        e, t = find_threshold_entropy(x, y, ind, 2, 1)
        self.assertAlmostEqual(e, 1)
        self.assertEqual(t, 2.0)

    def test_find_threshold_entropy_repeated(self):
        """Repeated attribute values: the threshold must be 1.0 every time.

        ``x`` holds three 1s followed by three 2s. In each case the class
        label changes inside a run of equal ``x`` values, never exactly at
        the boundary between the 1s and the 2s.
        """
        # (class labels, expected score); the expected threshold is always 1.0
        cases = [
            ([0, 0, 0, 0, 1, 1], 0.459147917027245),
            ([0, 0, 1, 1, 1, 1], 0.459147917027245),
            ([0, 1, 1, 1, 1, 1], 0.19087450462110966),
        ]
        for labels, expected_score in cases:
            # subTest reports every failing case separately and names the
            # label vector, instead of stopping at the first failure.
            with self.subTest(y=labels):
                # Fresh arrays per case so no state is shared between calls.
                x = np.array([1, 1, 1, 2, 2, 2], dtype=float)
                y = np.array(labels, dtype=float)
                ind = np.argsort(x, kind="stable")
                e, t = find_threshold_entropy(x, y, ind, 2, 1)
                self.assertAlmostEqual(e, expected_score)
                self.assertEqual(t, 1.0)

0 commit comments

Comments
 (0)