diff --git a/Orange/classification/tree.py b/Orange/classification/tree.py index b453cee0b50..b2d4adfc034 100644 --- a/Orange/classification/tree.py +++ b/Orange/classification/tree.py @@ -58,8 +58,8 @@ class TreeLearner(Learner): def __init__( self, *args, binarize=False, max_depth=None, min_samples_leaf=1, min_samples_split=2, sufficient_majority=0.95, - **kwargs): - super().__init__(*args, **kwargs) + preprocessors=None, **kwargs): + super().__init__(preprocessors=preprocessors) self.params = {} self.binarize = self.params['binarize'] = binarize self.min_samples_leaf = self.params['min_samples_leaf'] = min_samples_leaf diff --git a/Orange/tests/test_tree.py b/Orange/tests/test_tree.py index 974415c1598..b3342cda68e 100644 --- a/Orange/tests/test_tree.py +++ b/Orange/tests/test_tree.py @@ -2,17 +2,18 @@ # pylint: disable=missing-docstring import unittest +from unittest.mock import Mock import numpy as np import sklearn.tree as skl_tree from sklearn.tree._tree import TREE_LEAF from Orange.data import Table -from Orange.classification import SklTreeLearner +from Orange.classification import SklTreeLearner, TreeLearner from Orange.regression import SklTreeRegressionLearner -class TestTreeLearner(unittest.TestCase): +class TestSklTreeLearner(unittest.TestCase): def test_classification(self): table = Table('iris') learn = SklTreeLearner() @@ -28,6 +29,16 @@ def test_regression(self): self.assertTrue(np.all(table.Y.flatten() == pred)) +class TestTreeLearner(unittest.TestCase): + def test_uses_preprocessors(self): + iris = Table('iris') + mock_preprocessor = Mock(return_value=iris) + + tree = TreeLearner(preprocessors=[mock_preprocessor]) + tree(iris) + mock_preprocessor.assert_called_with(iris) + + class TestDecisionTreeClassifier(unittest.TestCase): @classmethod def setUpClass(cls):