Skip to content

Commit 6965c7a

Browse files
authored
Merge pull request #1477 from VesnaT/rf_fix
[FIX] OWRandomForest: Fix, refactor and widget tests
2 parents 8dea6c6 + d409fb2 commit 6965c7a

File tree

3 files changed

+135
-74
lines changed

3 files changed

+135
-74
lines changed

Orange/widgets/classify/owrandomforest.py

Lines changed: 33 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# -*- coding: utf-8 -*-
22
from PyQt4 import QtGui
3-
from PyQt4.QtGui import QLabel, QGridLayout
4-
from PyQt4.QtCore import Qt
53

64
from Orange.data import Table
75
from Orange.classification.random_forest import RandomForestLearner
@@ -29,72 +27,32 @@ class OWRandomForest(OWBaseLearner):
2927
index_output = settings.Setting(0)
3028

3129
def add_main_layout(self):
32-
form = QGridLayout()
33-
basic_box = gui.widgetBox(
34-
self.controlArea, "Basic Properties", orientation=form)
35-
36-
form.addWidget(QLabel(self.tr("Number of trees: ")),
37-
0, 0, Qt.AlignLeft)
38-
spin = gui.spin(basic_box, self, "n_estimators", minv=1, maxv=10000,
39-
callback=self.settings_changed, addToLayout=False,
40-
controlWidth=50)
41-
form.addWidget(spin, 0, 1, Qt.AlignRight)
42-
43-
max_features_cb = gui.checkBox(
44-
basic_box, self, "use_max_features",
45-
callback=self.settings_changed, addToLayout=False,
46-
label="Number of attributes considered at each split: ")
47-
48-
max_features_spin = gui.spin(
49-
basic_box, self, "max_features", 2, 50, addToLayout=False,
50-
callback=self.settings_changed, controlWidth=50)
51-
52-
form.addWidget(max_features_cb, 1, 0, Qt.AlignLeft)
53-
form.addWidget(max_features_spin, 1, 1, Qt.AlignRight)
54-
55-
random_state_cb = gui.checkBox(
56-
basic_box, self, "use_random_state", callback=self.settings_changed,
57-
addToLayout=False, label="Fixed seed for random generator: ")
58-
random_state_spin = gui.spin(
59-
basic_box, self, "random_state", 0, 2 ** 31 - 1, addToLayout=False,
60-
callback=self.settings_changed, controlWidth=50)
61-
62-
form.addWidget(random_state_cb, 2, 0, Qt.AlignLeft)
63-
form.addWidget(random_state_spin, 2, 1, Qt.AlignRight)
64-
self._max_features_spin = max_features_spin
65-
self._random_state_spin = random_state_spin
66-
67-
# Growth control
68-
form = QGridLayout()
69-
growth_box = gui.widgetBox(
70-
self.controlArea, "Growth Control", orientation=form)
71-
72-
max_depth_cb = gui.checkBox(
73-
growth_box, self, "use_max_depth",
30+
box = gui.vBox(self.controlArea, 'Basic Properties')
31+
self.n_estimators_spin = gui.spin(
32+
box, self, "n_estimators", minv=1, maxv=10000, controlWidth=50,
33+
label="Number of trees: ", callback=self.settings_changed)
34+
self.max_features_spin = gui.spin(
35+
box, self, "max_features", 2, 50, controlWidth=50,
36+
label="Number of attributes considered at each split: ",
37+
callback=self.settings_changed, checked="use_max_features",
38+
checkCallback=self.settings_changed)
39+
self.random_state_spin = gui.spin(
40+
box, self, "random_state", 0, 2 ** 31 - 1, controlWidth=50,
41+
label="Fixed seed for random generator: ",
42+
callback=self.settings_changed, checked="use_random_state",
43+
checkCallback=self.settings_changed)
44+
45+
box = gui.vBox(self.controlArea, "Growth Control")
46+
self.max_depth_spin = gui.spin(
47+
box, self, "max_depth", 1, 50, controlWidth=50,
7448
label="Limit depth of individual trees: ",
75-
callback=self.settings_changed,
76-
addToLayout=False)
77-
78-
max_depth_spin = gui.spin(
79-
growth_box, self, "max_depth", 1, 50, addToLayout=False,
80-
callback=self.settings_changed)
81-
82-
form.addWidget(max_depth_cb, 3, 0, Qt.AlignLeft)
83-
form.addWidget(max_depth_spin, 3, 1, Qt.AlignRight)
84-
85-
min_samples_split_cb = gui.checkBox(
86-
growth_box, self, "use_min_samples_split",
49+
callback=self.settings_changed, checked="use_max_depth",
50+
checkCallback=self.settings_changed)
51+
self.min_samples_split_spin = gui.spin(
52+
box, self, "min_samples_split", 1, 1000, controlWidth=50,
8753
label="Do not split subsets smaller than: ",
88-
callback=self.settings_changed, addToLayout=False)
89-
90-
min_samples_split_spin = gui.spin(
91-
growth_box, self, "min_samples_split", 1, 1000, addToLayout=False,
92-
callback=self.settings_changed)
93-
94-
form.addWidget(min_samples_split_cb, 4, 0, Qt.AlignLeft)
95-
form.addWidget(min_samples_split_spin, 4, 1, Qt.AlignRight)
96-
self._max_depth_spin = max_depth_spin
97-
self._min_samples_split_spin = min_samples_split_spin
54+
callback=self.settings_changed, checked="use_min_samples_split",
55+
checkCallback=self.settings_changed)
9856

9957
# Index on the output
10058
# gui.doubleSpin(self.controlArea, self, "index_output", 0, 10000, 1,
@@ -113,12 +71,15 @@ def create_learner(self):
11371

11472
return self.LEARNER(preprocessors=self.preprocessors, **common_args)
11573

116-
def settings_changed(self):
117-
super().settings_changed()
118-
self._max_features_spin.setEnabled(self.use_max_features)
119-
self._random_state_spin.setEnabled(self.use_random_state)
120-
self._max_depth_spin.setEnabled(self.use_max_depth)
121-
self._min_samples_split_spin.setEnabled(self.use_min_samples_split)
74+
def check_data(self):
75+
if super().check_data():
76+
n_features = len(self.data.domain.attributes)
77+
if self.use_max_features and self.max_features > n_features:
78+
self.error(self.DATA_ERROR_ID,
79+
"Number of splitting attributes should "
80+
"be smaller than number of features.")
81+
self.valid_data = False
82+
return self.valid_data
12283

12384
def get_learner_parameters(self):
12485
"""Called by send report to list the parameters of the learner."""
Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,61 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
33
from Orange.widgets.classify.owrandomforest import OWRandomForest
4-
from Orange.widgets.tests.base import WidgetTest, WidgetLearnerTestMixin
4+
from Orange.widgets.tests.base import (WidgetTest, WidgetLearnerTestMixin,
5+
GuiToParam)
56

67

78
class TestOWRandomForest(WidgetTest, WidgetLearnerTestMixin):
89
def setUp(self):
910
self.widget = self.create_widget(OWRandomForest,
1011
stored_settings={"auto_apply": False})
1112
self.init()
13+
n_est_spin = self.widget.n_estimators_spin
14+
max_f_spin = self.widget.max_features_spin[1]
15+
rs_spin = self.widget.random_state_spin[1]
16+
max_d_spin = self.widget.max_depth_spin[1]
17+
min_s_spin = self.widget.min_samples_split_spin[1]
18+
n_est_min_max = [n_est_spin.minimum() * 10, n_est_spin.minimum()]
19+
min_s_min_max = [min_s_spin.minimum(), min_s_spin.maximum()]
20+
self.gui_to_params = [
21+
GuiToParam("n_estimators", n_est_spin, lambda x: x.value(),
22+
lambda i, x: x.setValue(i), n_est_min_max, n_est_min_max),
23+
GuiToParam("max_features", max_f_spin, lambda x: "auto",
24+
lambda i, x: x.setValue(i), ["auto"], [0]),
25+
GuiToParam("random_state", rs_spin, lambda x: None,
26+
lambda i, x: x.setValue(i), [None], [0]),
27+
GuiToParam("max_depth", max_d_spin, lambda x: None,
28+
lambda i, x: x.setValue(i), [None], [0]),
29+
GuiToParam("min_samples_split", min_s_spin, lambda x: x.value(),
30+
lambda i, x: x.setValue(i), min_s_min_max, min_s_min_max)]
31+
32+
def test_parameters_checked(self):
33+
"""Check learner and model for various values of all parameters
34+
when all properties are checked
35+
"""
36+
self.widget.max_features_spin[0].click()
37+
self.widget.random_state_spin[0].click()
38+
self.widget.max_depth_spin[0].click()
39+
for j in range(1, 4):
40+
el = self.gui_to_params[j]
41+
el_min_max = [el.gui_el.minimum(), el.gui_el.maximum()]
42+
self.gui_to_params[j] = GuiToParam(
43+
el.name, el.gui_el, lambda x: x.value(),
44+
lambda i, x: x.setValue(i), el_min_max, el_min_max)
45+
self.test_parameters()
46+
# FIXME: checkboxes are reset to default, since the widget settings were saved
47+
self.widget.max_features_spin[0].setCheckState(False)
48+
self.widget.random_state_spin[0].setCheckState(False)
49+
self.widget.max_depth_spin[0].setCheckState(False)
50+
51+
def test_parameters_unchecked(self):
52+
"""Check learner and model for various values of all parameters
53+
when properties are not checked
54+
"""
55+
self.widget.min_samples_split_spin[0].click()
56+
el = self.gui_to_params[4]
57+
self.gui_to_params[4] = GuiToParam(el.name, el.gui_el, lambda x: 2,
58+
lambda i, x: x.setValue(i), [2], [0])
59+
self.test_parameters()
60+
# FIXME: checkboxes are reset to default, since the widget settings were saved
61+
self.widget.min_samples_split_spin[0].setCheckState(True)

Orange/widgets/regression/tests/test_owrandomforestregression.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,61 @@
22
# pylint: disable=missing-docstring
33
from Orange.widgets.regression.owrandomforestregression import \
44
OWRandomForestRegression
5-
from Orange.widgets.tests.base import WidgetTest, WidgetLearnerTestMixin
5+
from Orange.widgets.tests.base import (WidgetTest, WidgetLearnerTestMixin,
6+
GuiToParam)
67

78

89
class TestOWRandomForestRegression(WidgetTest, WidgetLearnerTestMixin):
910
def setUp(self):
1011
self.widget = self.create_widget(OWRandomForestRegression,
1112
stored_settings={"auto_apply": False})
1213
self.init()
14+
n_est_spin = self.widget.n_estimators_spin
15+
max_f_spin = self.widget.max_features_spin[1]
16+
rs_spin = self.widget.random_state_spin[1]
17+
max_d_spin = self.widget.max_depth_spin[1]
18+
min_s_spin = self.widget.min_samples_split_spin[1]
19+
n_est_min_max = [n_est_spin.minimum() * 10, n_est_spin.minimum()]
20+
min_s_min_max = [min_s_spin.minimum(), min_s_spin.maximum()]
21+
self.gui_to_params = [
22+
GuiToParam("n_estimators", n_est_spin, lambda x: x.value(),
23+
lambda i, x: x.setValue(i), n_est_min_max, n_est_min_max),
24+
GuiToParam("max_features", max_f_spin, lambda x: "auto",
25+
lambda i, x: x.setValue(i), ["auto"], [0]),
26+
GuiToParam("random_state", rs_spin, lambda x: None,
27+
lambda i, x: x.setValue(i), [None], [0]),
28+
GuiToParam("max_depth", max_d_spin, lambda x: None,
29+
lambda i, x: x.setValue(i), [None], [0]),
30+
GuiToParam("min_samples_split", min_s_spin, lambda x: x.value(),
31+
lambda i, x: x.setValue(i), min_s_min_max, min_s_min_max)]
32+
33+
def test_parameters_checked(self):
34+
"""Check learner and model for various values of all parameters
35+
when all properties are checked
36+
"""
37+
self.widget.max_features_spin[0].click()
38+
self.widget.random_state_spin[0].click()
39+
self.widget.max_depth_spin[0].click()
40+
for j in range(1, 4):
41+
el = self.gui_to_params[j]
42+
el_min_max = [el.gui_el.minimum(), el.gui_el.maximum()]
43+
self.gui_to_params[j] = GuiToParam(
44+
el.name, el.gui_el, lambda x: x.value(),
45+
lambda i, x: x.setValue(i), el_min_max, el_min_max)
46+
self.test_parameters()
47+
# FIXME: checkboxes are reset to default, since the widget settings were saved
48+
self.widget.max_features_spin[0].setCheckState(False)
49+
self.widget.random_state_spin[0].setCheckState(False)
50+
self.widget.max_depth_spin[0].setCheckState(False)
51+
52+
def test_parameters_unchecked(self):
53+
"""Check learner and model for various values of all parameters
54+
when properties are not checked
55+
"""
56+
self.widget.min_samples_split_spin[0].click()
57+
el = self.gui_to_params[4]
58+
self.gui_to_params[4] = GuiToParam(el.name, el.gui_el, lambda x: 2,
59+
lambda i, x: x.setValue(i), [2], [0])
60+
self.test_parameters()
61+
# FIXME: checkboxes are reset to default, since the widget settings were saved
62+
self.widget.min_samples_split_spin[0].setCheckState(True)

0 commit comments

Comments
 (0)