Skip to content

Commit e131cbe

Browse files
authored
Merge pull request #2566 from thocevar/tree
[FIX] Tree: Reintroduce preprocessors.
2 parents 9b9bec1 + 445dfd1 commit e131cbe

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

Orange/classification/tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ class TreeLearner(Learner):
5858
def __init__(
5959
self, *args, binarize=False, max_depth=None,
6060
min_samples_leaf=1, min_samples_split=2, sufficient_majority=0.95,
61-
**kwargs):
62-
super().__init__(*args, **kwargs)
61+
preprocessors=None, **kwargs):
62+
super().__init__(preprocessors=preprocessors)
6363
self.params = {}
6464
self.binarize = self.params['binarize'] = binarize
6565
self.min_samples_leaf = self.params['min_samples_leaf'] = min_samples_leaf

Orange/tests/test_tree.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,18 @@
22
# pylint: disable=missing-docstring
33

44
import unittest
5+
from unittest.mock import Mock
56

67
import numpy as np
78
import sklearn.tree as skl_tree
89
from sklearn.tree._tree import TREE_LEAF
910

1011
from Orange.data import Table
11-
from Orange.classification import SklTreeLearner
12+
from Orange.classification import SklTreeLearner, TreeLearner
1213
from Orange.regression import SklTreeRegressionLearner
1314

1415

15-
class TestTreeLearner(unittest.TestCase):
16+
class TestSklTreeLearner(unittest.TestCase):
1617
def test_classification(self):
1718
table = Table('iris')
1819
learn = SklTreeLearner()
@@ -28,6 +29,16 @@ def test_regression(self):
2829
self.assertTrue(np.all(table.Y.flatten() == pred))
2930

3031

32+
class TestTreeLearner(unittest.TestCase):
33+
def test_uses_preprocessors(self):
34+
iris = Table('iris')
35+
mock_preprocessor = Mock(return_value=iris)
36+
37+
tree = TreeLearner(preprocessors=[mock_preprocessor])
38+
tree(iris)
39+
mock_preprocessor.assert_called_with(iris)
40+
41+
3142
class TestDecisionTreeClassifier(unittest.TestCase):
3243
@classmethod
3344
def setUpClass(cls):

0 commit comments

Comments
 (0)