From 9ed8b9c811f26691a2556974059c6a60c283e509 Mon Sep 17 00:00:00 2001 From: thocevar Date: Fri, 8 Sep 2017 14:34:54 +0200 Subject: [PATCH 1/2] Explicitly add preprocessors to the list of parameters in the constructor. --- Orange/classification/tree.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Orange/classification/tree.py b/Orange/classification/tree.py index b453cee0b50..b2d4adfc034 100644 --- a/Orange/classification/tree.py +++ b/Orange/classification/tree.py @@ -58,8 +58,8 @@ class TreeLearner(Learner): def __init__( self, *args, binarize=False, max_depth=None, min_samples_leaf=1, min_samples_split=2, sufficient_majority=0.95, - **kwargs): - super().__init__(*args, **kwargs) + preprocessors=None, **kwargs): + super().__init__(preprocessors=preprocessors) self.params = {} self.binarize = self.params['binarize'] = binarize self.min_samples_leaf = self.params['min_samples_leaf'] = min_samples_leaf From 445dfd147dde62c187d10d87a12fbf2c591e0690 Mon Sep 17 00:00:00 2001 From: astaric Date: Wed, 13 Sep 2017 23:57:37 +0200 Subject: [PATCH 2/2] Add test --- Orange/tests/test_tree.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/Orange/tests/test_tree.py b/Orange/tests/test_tree.py index 974415c1598..b3342cda68e 100644 --- a/Orange/tests/test_tree.py +++ b/Orange/tests/test_tree.py @@ -2,17 +2,18 @@ # pylint: disable=missing-docstring import unittest +from unittest.mock import Mock import numpy as np import sklearn.tree as skl_tree from sklearn.tree._tree import TREE_LEAF from Orange.data import Table -from Orange.classification import SklTreeLearner +from Orange.classification import SklTreeLearner, TreeLearner from Orange.regression import SklTreeRegressionLearner -class TestTreeLearner(unittest.TestCase): +class TestSklTreeLearner(unittest.TestCase): def test_classification(self): table = Table('iris') learn = SklTreeLearner() @@ -28,6 +29,16 @@ def test_regression(self): self.assertTrue(np.all(table.Y.flatten() == pred)) +class TestTreeLearner(unittest.TestCase): + def test_uses_preprocessors(self): + iris = Table('iris') + mock_preprocessor = Mock(return_value=iris) + + tree = TreeLearner(preprocessors=[mock_preprocessor]) + tree(iris) + mock_preprocessor.assert_called_with(iris) + + class TestDecisionTreeClassifier(unittest.TestCase): @classmethod def setUpClass(cls):