Skip to content

Commit 75125ad

Browse files
authored
Merge pull request #1818 from pavlin-policar/merge-tree
Merge OWTree
2 parents b890868 + 3408f7d commit 75125ad

File tree

9 files changed

+207
-76
lines changed

9 files changed

+207
-76
lines changed

Orange/modelling/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .base import *
22

33
from .knn import *
4+
from .tree import *

Orange/modelling/tree.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from Orange.classification import SklTreeLearner
2+
from Orange.classification import TreeLearner as ClassificationTreeLearner
3+
from Orange.modelling import Fitter
4+
from Orange.regression import TreeLearner as RegressionTreeLearner
5+
from Orange.regression.tree import SklTreeRegressionLearner
6+
from Orange.tree import TreeModel
7+
8+
9+
class SklTreeLearner(Fitter):
10+
name = 'tree'
11+
12+
__fits__ = {'classification': SklTreeLearner,
13+
'regression': SklTreeRegressionLearner}
14+
15+
16+
class TreeLearner(Fitter):
17+
name = 'tree'
18+
19+
__fits__ = {'classification': ClassificationTreeLearner,
20+
'regression': RegressionTreeLearner}
21+
22+
__returns__ = TreeModel

Orange/widgets/classify/owclassificationtree.py

Lines changed: 7 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,12 @@
11
"""General tree learner base widget, and classification tree widget"""
22

3-
from collections import OrderedDict
4-
5-
from AnyQt.QtCore import Qt
6-
73
from Orange.data import Table
8-
from Orange.classification.tree import TreeLearner
9-
from Orange.widgets import gui
10-
from Orange.widgets.settings import Setting
4+
from Orange.modelling.tree import TreeLearner
5+
from Orange.widgets.model.owtree import OWTreeLearner
116
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
127

138

14-
class OWTreeLearner(OWBaseLearner):
15-
"""Base widget for tree learners"""
16-
binary_trees = Setting(True)
17-
limit_min_leaf = Setting(True)
18-
min_leaf = Setting(2)
19-
limit_min_internal = Setting(True)
20-
min_internal = Setting(5)
21-
limit_depth = Setting(True)
22-
max_depth = Setting(100)
23-
24-
spin_boxes = (
25-
("Min. number of instances in leaves: ",
26-
"limit_min_leaf", "min_leaf", 1, 1000),
27-
("Do not split subsets smaller than: ",
28-
"limit_min_internal", "min_internal", 1, 1000),
29-
("Limit the maximal tree depth to: ",
30-
"limit_depth", "max_depth", 1, 1000))
31-
32-
def add_main_layout(self):
33-
box = gui.vBox(self.controlArea, True)
34-
# the checkbox is put into vBox for alignemnt with other checkboxes
35-
gui.checkBox(gui.vBox(box), self, "binary_trees", "Induce binary tree",
36-
callback=self.settings_changed)
37-
for label, check, setting, fromv, tov in self.spin_boxes:
38-
gui.spin(box, self, setting, fromv, tov, label=label, checked=check,
39-
alignment=Qt.AlignRight, callback=self.settings_changed,
40-
checkCallback=self.settings_changed, controlWidth=80)
41-
42-
def learner_kwargs(self):
43-
# Pylint doesn't get our Settings
44-
# pylint: disable=invalid-sequence-index
45-
return dict(
46-
max_depth=(None, self.max_depth)[self.limit_depth],
47-
min_samples_split=(2, self.min_internal)[self.limit_min_internal],
48-
min_samples_leaf=(1, self.min_leaf)[self.limit_min_leaf],
49-
binarize=self.binary_trees,
50-
preprocessors=self.preprocessors)
51-
52-
def create_learner(self):
53-
# pylint: disable=not-callable
54-
return self.LEARNER(**self.learner_kwargs())
55-
56-
def get_learner_parameters(self):
57-
from Orange.canvas.report import plural_w
58-
items = OrderedDict()
59-
items["Pruning"] = ", ".join(s for s, c in (
60-
(plural_w("at least {number} instance{s} in leaves",
61-
self.min_leaf), self.limit_min_leaf),
62-
(plural_w("at least {number} instance{s} in internal nodes",
63-
self.min_internal), self.limit_min_internal),
64-
("maximum depth {}".format(self.max_depth), self.limit_depth)
65-
) if c) or "None"
66-
items["Binary trees"] = ("No", "Yes")[self.binary_trees]
67-
return items
68-
69-
70-
class OWClassificationTree(OWTreeLearner):
9+
class OWTreeLearner(OWTreeLearner):
7110
"""Classification tree algorithm with forward pruning."""
7211

7312
name = "Classification Tree"
@@ -76,9 +15,6 @@ class OWClassificationTree(OWTreeLearner):
7615

7716
LEARNER = TreeLearner
7817

79-
limit_majority = Setting(True)
80-
sufficient_majority = Setting(95)
81-
8218
spin_boxes = \
8319
OWTreeLearner.spin_boxes[:-1] + \
8420
(("Stop when majority reaches [%]: ",
@@ -98,6 +34,10 @@ def get_learner_parameters(self):
9834
"stop splitting when the majority class reaches {} %".format(
9935
self.sufficient_majority)
10036

37+
# Disable the special classification layout to be used when widgets are
38+
# fully merged
39+
add_classification_layout = OWBaseLearner.add_classification_layout
40+
10141

10242
def _test():
10343
import sys

Orange/widgets/classify/tests/test_owclassificationtree.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
33
from Orange.base import Model
4-
from Orange.widgets.classify.owclassificationtree import OWClassificationTree
4+
from Orange.widgets.classify.owclassificationtree import OWTreeLearner
55
from Orange.widgets.tests.base import (WidgetTest, DefaultParameterMapping,
66
ParameterMapping, WidgetLearnerTestMixin)
77

88

99
class TestOWClassificationTree(WidgetTest, WidgetLearnerTestMixin):
1010
def setUp(self):
11-
self.widget = self.create_widget(OWClassificationTree,
12-
stored_settings={"auto_apply": False})
11+
self.widget = self.create_widget(
12+
OWTreeLearner, stored_settings={"auto_apply": False})
1313
self.init()
1414
self.model_class = Model
1515

Lines changed: 15 additions & 0 deletions
Loading

Orange/widgets/model/owtree.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Tree learner widget"""
2+
3+
from AnyQt.QtCore import Qt
4+
from collections import OrderedDict
5+
6+
from Orange.data import Table
7+
from Orange.modelling.tree import TreeLearner
8+
from Orange.widgets import gui
9+
from Orange.widgets.settings import Setting
10+
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
11+
12+
13+
class OWTreeLearner(OWBaseLearner):
14+
"""Tree algorithm with forward pruning."""
15+
name = "Tree"
16+
description = "A tree algorithm with forward pruning."
17+
icon = "icons/Tree.svg"
18+
priority = 30
19+
20+
LEARNER = TreeLearner
21+
22+
binary_trees = Setting(True)
23+
limit_min_leaf = Setting(True)
24+
min_leaf = Setting(2)
25+
limit_min_internal = Setting(True)
26+
min_internal = Setting(5)
27+
limit_depth = Setting(True)
28+
max_depth = Setting(100)
29+
30+
# Classification only settings
31+
limit_majority = Setting(True)
32+
sufficient_majority = Setting(95)
33+
34+
spin_boxes = (
35+
("Min. number of instances in leaves: ",
36+
"limit_min_leaf", "min_leaf", 1, 1000),
37+
("Do not split subsets smaller than: ",
38+
"limit_min_internal", "min_internal", 1, 1000),
39+
("Limit the maximal tree depth to: ",
40+
"limit_depth", "max_depth", 1, 1000))
41+
42+
classification_spin_boxes = (
43+
("Stop when majority reaches [%]: ",
44+
"limit_majority", "sufficient_majority", 51, 100),)
45+
46+
def add_main_layout(self):
47+
box = gui.widgetBox(self.controlArea, 'Parameters')
48+
# the checkbox is put into vBox for alignemnt with other checkboxes
49+
gui.checkBox(gui.vBox(box), self, "binary_trees", "Induce binary tree",
50+
callback=self.settings_changed)
51+
for label, check, setting, fromv, tov in self.spin_boxes:
52+
gui.spin(box, self, setting, fromv, tov, label=label,
53+
checked=check, alignment=Qt.AlignRight,
54+
callback=self.settings_changed,
55+
checkCallback=self.settings_changed, controlWidth=80)
56+
57+
def add_classification_layout(self, box):
58+
for label, check, setting, minv, maxv in self.classification_spin_boxes:
59+
gui.spin(box, self, setting, minv, maxv,
60+
label=label, checked=check, alignment=Qt.AlignRight,
61+
callback=self.settings_changed, controlWidth=80,
62+
checkCallback=self.settings_changed)
63+
64+
def learner_kwargs(self):
65+
# Pylint doesn't get our Settings
66+
# pylint: disable=invalid-sequence-index
67+
return dict(
68+
max_depth=(None, self.max_depth)[self.limit_depth],
69+
min_samples_split=(2, self.min_internal)[self.limit_min_internal],
70+
min_samples_leaf=(1, self.min_leaf)[self.limit_min_leaf],
71+
binarize=self.binary_trees,
72+
preprocessors=self.preprocessors,
73+
sufficient_majority=(1, self.sufficient_majority / 100)[
74+
self.limit_majority])
75+
76+
def create_learner(self):
77+
# pylint: disable=not-callable
78+
return self.LEARNER(**self.learner_kwargs())
79+
80+
def get_learner_parameters(self):
81+
from Orange.canvas.report import plural_w
82+
items = OrderedDict()
83+
items["Pruning"] = ", ".join(s for s, c in (
84+
(plural_w("at least {number} instance{s} in leaves",
85+
self.min_leaf), self.limit_min_leaf),
86+
(plural_w("at least {number} instance{s} in internal nodes",
87+
self.min_internal), self.limit_min_internal),
88+
("maximum depth {}".format(self.max_depth), self.limit_depth)
89+
) if c) or "None"
90+
if self.limit_majority:
91+
items["Splitting"] = "Stop splitting when majority reaches %d%% " \
92+
"(classification only)" % \
93+
self.sufficient_majority
94+
items["Binary trees"] = ("No", "Yes")[self.binary_trees]
95+
return items
96+
97+
98+
def _test():
99+
import sys
100+
from AnyQt.QtWidgets import QApplication
101+
102+
a = QApplication(sys.argv)
103+
ow = OWTreeLearner()
104+
d = Table(sys.argv[1] if len(sys.argv) > 1 else 'iris')
105+
ow.set_data(d)
106+
ow.show()
107+
a.exec_()
108+
ow.saveSettings()
109+
110+
if __name__ == "__main__":
111+
_test()
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from Orange.base import Model
2+
from Orange.widgets.model.owtree import OWTreeLearner
3+
from Orange.widgets.tests.base import (
4+
DefaultParameterMapping,
5+
ParameterMapping,
6+
WidgetLearnerTestMixin,
7+
WidgetTest,
8+
)
9+
10+
11+
class TestOWClassificationTree(WidgetTest, WidgetLearnerTestMixin):
12+
def setUp(self):
13+
self.widget = self.create_widget(
14+
OWTreeLearner, stored_settings={"auto_apply": False})
15+
self.init()
16+
self.model_class = Model
17+
18+
self.parameters = [
19+
ParameterMapping.from_attribute(self.widget, 'max_depth'),
20+
ParameterMapping.from_attribute(
21+
self.widget, 'min_internal', 'min_samples_split'),
22+
ParameterMapping.from_attribute(
23+
self.widget, 'min_leaf', 'min_samples_leaf')]
24+
# NB. sufficient_majority is divided by 100, so it cannot be tested
25+
# like this
26+
27+
self.checks = [sb.gui_element.cbox for sb in self.parameters]
28+
29+
def test_parameters_unchecked(self):
30+
"""Check learner and model for various values of all parameters
31+
when pruning parameters are not checked
32+
"""
33+
for cb in self.checks:
34+
cb.setCheckState(False)
35+
self.parameters = [DefaultParameterMapping(par.name, val)
36+
for par, val in zip(self.parameters, (None, 2, 1))]
37+
self.test_parameters()

Orange/widgets/regression/owregressiontree.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
"""Widget for induction of regression trees"""
22

3-
from Orange.regression.tree import TreeLearner
4-
from Orange.widgets.classify.owclassificationtree import OWTreeLearner
3+
from Orange.modelling.tree import TreeLearner
4+
from Orange.widgets.model.owtree import OWTreeLearner
5+
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
56

67

7-
class OWRegressionTree(OWTreeLearner):
8+
class OWTreeLearner(OWTreeLearner):
89
name = "Regression Tree"
910
description = "A regression tree algorithm with forward pruning."
1011
icon = "icons/RegressionTree.svg"
1112
priority = 30
1213

1314
LEARNER = TreeLearner
1415

16+
# Disable the special classification layout to be used when widgets are
17+
# fully merged
18+
add_classification_layout = OWBaseLearner.add_classification_layout
19+
1520

1621
def _test():
1722
import sys

Orange/widgets/regression/tests/test_owregressiontree.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
33
from Orange.base import Model
4-
from Orange.widgets.regression.owregressiontree import OWRegressionTree
4+
from Orange.widgets.regression.owregressiontree import OWTreeLearner
55
from Orange.widgets.tests.base import (WidgetTest, DefaultParameterMapping,
66
ParameterMapping, WidgetLearnerTestMixin)
77

88

99
class TestOWRegressionTree(WidgetTest, WidgetLearnerTestMixin):
1010
def setUp(self):
11-
self.widget = self.create_widget(OWRegressionTree,
12-
stored_settings={"auto_apply": False})
11+
self.widget = self.create_widget(
12+
OWTreeLearner, stored_settings={"auto_apply": False})
1313
self.init()
1414
self.model_class = Model
1515
self.parameters = [

0 commit comments

Comments
 (0)