Skip to content

Commit b7df0f7

Browse files
authored
Merge pull request #1474 from VesnaT/ada_boost_fix
[FIX] Fix AdaBoost widgets and add some tests
2 parents 277bc4f + a9b395a commit b7df0f7

File tree

6 files changed

+146
-73
lines changed

6 files changed

+146
-73
lines changed

Orange/widgets/classify/owadaboost.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,39 +26,43 @@ class OWAdaBoostClassification(OWBaseLearner):
2626
learning_rate = Setting(1.)
2727
algorithm = Setting(0)
2828

29+
DEFAULT_BASE_ESTIMATOR = TreeLearner()
30+
2931
def add_main_layout(self):
3032
box = gui.widgetBox(self.controlArea, "Parameters")
31-
self.base_estimator = TreeLearner()
32-
self.base_label = gui.label(box, self, "Base estimator: " + self.base_estimator.name)
33-
34-
gui.spin(box, self, "n_estimators", 1, 100, label="Number of estimators:",
35-
alignment=Qt.AlignRight, callback=self.settings_changed)
36-
gui.doubleSpin(box, self, "learning_rate", 1e-5, 1.0, 1e-5,
37-
label="Learning rate:", decimals=5, alignment=Qt.AlignRight,
38-
controlWidth=90, callback=self.settings_changed)
33+
self.base_estimator = self.DEFAULT_BASE_ESTIMATOR
34+
self.base_label = gui.label(
35+
box, self, "Base estimator: " + self.base_estimator.name)
36+
37+
self.n_estimators_spin = gui.spin(
38+
box, self, "n_estimators", 1, 100, label="Number of estimators:",
39+
alignment=Qt.AlignRight, callback=self.settings_changed)
40+
self.learning_rate_spin = gui.doubleSpin(
41+
box, self, "learning_rate", 1e-5, 1.0, 1e-5, label="Learning rate:",
42+
decimals=5, alignment=Qt.AlignRight, controlWidth=90,
43+
callback=self.settings_changed)
3944
self.add_specific_parameters(box)
4045

4146
def add_specific_parameters(self, box):
42-
gui.comboBox(box, self, "algorithm", label="Algorithm:",
43-
orientation=Qt.Horizontal, items=self.losses,
44-
callback=self.settings_changed)
47+
self.algorithm_combo = gui.comboBox(
48+
box, self, "algorithm", label="Algorithm:", items=self.losses,
49+
orientation=Qt.Horizontal, callback=self.settings_changed)
4550

4651
def create_learner(self):
4752
return self.LEARNER(
4853
base_estimator=self.base_estimator,
4954
n_estimators=self.n_estimators,
55+
learning_rate=self.learning_rate,
5056
preprocessors=self.preprocessors,
5157
algorithm=self.losses[self.algorithm]
5258
)
5359

54-
def set_base_learner(self, model):
55-
self.base_estimator = model
56-
if self.base_estimator:
57-
self.base_label.setText("Base estimator: " + self.base_estimator.name)
58-
self.apply_button.setDisabled(False)
59-
else:
60-
self.base_label.setText("No base estimator")
61-
self.apply_button.setDisabled(True)
60+
def set_base_learner(self, learner):
61+
self.base_estimator = learner if learner \
62+
else self.DEFAULT_BASE_ESTIMATOR
63+
self.base_label.setText("Base estimator: " + self.base_estimator.name)
64+
if self.auto_apply:
65+
self.apply()
6266

6367
def get_learner_parameters(self):
6468
return (("Base estimator", self.base_estimator),
@@ -69,6 +73,7 @@ def get_learner_parameters(self):
6973
if __name__ == "__main__":
7074
import sys
7175
from PyQt4.QtGui import QApplication
76+
7277
a = QApplication(sys.argv)
7378
ow = OWAdaBoostClassification()
7479
ow.set_data(Table("iris"))
Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,52 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
3-
from PyQt4 import QtGui
4-
5-
from Orange.widgets.tests.base import WidgetTest
3+
from Orange.classification import TreeLearner, KNNLearner
64
from Orange.widgets.classify.owadaboost import OWAdaBoostClassification
5+
from Orange.widgets.tests.base import (WidgetTest, WidgetLearnerTestMixin,
6+
GuiToParam)
77

88

9-
10-
class TestOWAdaBoostClassification(WidgetTest):
11-
9+
class TestOWAdaBoostClassification(WidgetTest, WidgetLearnerTestMixin):
1210
def setUp(self):
13-
self.widget = self.create_widget(OWAdaBoostClassification)
14-
self.spinners = []
15-
self.spinners.append(self.widget.findChildren(QtGui.QSpinBox)[0])
16-
self.spinners.append(self.widget.findChildren(QtGui.QDoubleSpinBox)[0])
17-
self.combobox_algorithm = self.widget.findChildren(QtGui.QComboBox)[0]
18-
19-
def test_visible_boxes(self):
20-
""" Check if boxes are visible """
21-
self.assertEqual(self.spinners[0].isHidden(), False)
22-
self.assertEqual(self.spinners[1].isHidden(), False)
23-
self.assertEqual(self.combobox_algorithm.isHidden(), False)
24-
25-
def test_parameters_on_output(self):
26-
""" Check right paramaters on output """
27-
self.widget.apply()
28-
learner_params = self.widget.learner.params
29-
self.assertEqual(learner_params.get("n_estimators"), self.spinners[0].value())
30-
self.assertEqual(learner_params.get("learning_rate"), self.spinners[1].value())
31-
self.assertEqual(learner_params.get('algorithm'), self.combobox_algorithm.currentText())
32-
33-
34-
def test_output_algorithm(self):
35-
""" Check if right learning algorithm is on output when we change algorithm """
36-
for index, algorithmName in enumerate(self.widget.losses):
37-
self.combobox_algorithm.setCurrentIndex(index)
38-
self.combobox_algorithm.activated.emit(index)
39-
self.assertEqual(self.combobox_algorithm.currentText(), algorithmName)
40-
self.widget.apply()
41-
self.assertEqual(self.widget.learner.params.get("algorithm").capitalize(),
42-
self.combobox_algorithm.currentText().capitalize())
43-
44-
def test_learner_on_output(self):
45-
""" Check if learner is on output after create widget and apply """
46-
self.widget.apply()
47-
self.assertNotEqual(self.widget.learner, None)
11+
self.widget = self.create_widget(OWAdaBoostClassification,
12+
stored_settings={"auto_apply": False})
13+
self.init()
14+
15+
def combo_set_value(i, x):
16+
x.activated.emit(i)
17+
x.setCurrentIndex(i)
18+
19+
losses = self.widget.losses
20+
nest_spin = self.widget.n_estimators_spin
21+
nest_min_max = [nest_spin.minimum(), nest_spin.maximum()]
22+
rate_spin = self.widget.learning_rate_spin
23+
rate_min_max = [rate_spin.minimum(), rate_spin.maximum()]
24+
self.gui_to_params = [
25+
GuiToParam('algorithm', self.widget.algorithm_combo,
26+
lambda x: x.currentText(),
27+
combo_set_value, losses, list(range(len(losses)))),
28+
GuiToParam('learning_rate', rate_spin, lambda x: x.value(),
29+
lambda i, x: x.setValue(i), rate_min_max, rate_min_max),
30+
GuiToParam('n_estimators', nest_spin, lambda x: x.value(),
31+
lambda i, x: x.setValue(i), nest_min_max, nest_min_max)]
32+
33+
def test_input_learner(self):
34+
"""Check if base learner properly changes with learner on the input"""
35+
max_depth = 2
36+
default_base_est = self.widget.base_estimator
37+
self.assertIsInstance(default_base_est, TreeLearner)
38+
self.assertIsNone(default_base_est.params.get("max_depth"))
39+
self.send_signal("Learner", TreeLearner(max_depth=max_depth))
40+
self.assertEqual(self.widget.base_estimator.params.get("max_depth"),
41+
max_depth)
42+
self.widget.apply_button.button.click()
43+
output_base_est = self.get_output("Learner").params.get("base_estimator")
44+
self.assertEqual(output_base_est.max_depth, max_depth)
45+
46+
def test_input_learner_disconnect(self):
47+
"""Check base learner after disconnecting learner on the input"""
48+
self.send_signal("Learner", KNNLearner())
49+
self.assertIsInstance(self.widget.base_estimator, KNNLearner)
50+
self.send_signal("Learner", None)
51+
self.assertEqual(self.widget.base_estimator,
52+
self.widget.DEFAULT_BASE_ESTIMATOR)

Orange/widgets/regression/owadaboostregression.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from PyQt4.QtCore import Qt
22

33
from Orange.regression.base_regression import LearnerRegression
4+
from Orange.regression import TreeRegressionLearner
45
from Orange.data import Table
56
from Orange.ensembles import SklAdaBoostRegressionLearner
67
from Orange.widgets import gui
@@ -22,15 +23,18 @@ class OWAdaBoostRegression(owadaboost.OWAdaBoostClassification):
2223
losses = ["Linear", "Square", "Exponential"]
2324
loss = Setting(0)
2425

26+
DEFAULT_BASE_ESTIMATOR = TreeRegressionLearner()
27+
2528
def add_specific_parameters(self, box):
26-
gui.comboBox(box, self, "loss", label="Loss:",
27-
orientation=Qt.Horizontal, items=self.losses,
28-
callback=self.settings_changed)
29+
self.loss_combo = gui.comboBox(
30+
box, self, "loss", label="Loss:", orientation=Qt.Horizontal,
31+
items=self.losses, callback=self.settings_changed)
2932

3033
def create_learner(self):
3134
return self.LEARNER(
3235
base_estimator=self.base_estimator,
3336
n_estimators=self.n_estimators,
37+
learning_rate=self.learning_rate,
3438
preprocessors=self.preprocessors,
3539
loss=self.losses[self.loss].lower()
3640
)
@@ -44,6 +48,7 @@ def get_learner_parameters(self):
4448
if __name__ == "__main__":
4549
import sys
4650
from PyQt4.QtGui import QApplication
51+
4752
a = QApplication(sys.argv)
4853
ow = OWAdaBoostRegression()
4954
ow.set_data(Table("housing"))
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Test methods with long descriptive names can omit docstrings
2+
# pylint: disable=missing-docstring
3+
from Orange.regression import TreeRegressionLearner, KNNRegressionLearner
4+
from Orange.widgets.regression.owadaboostregression import OWAdaBoostRegression
5+
from Orange.widgets.tests.base import (WidgetTest, WidgetLearnerTestMixin,
6+
GuiToParam)
7+
8+
9+
class TestOWAdaBoostRegression(WidgetTest, WidgetLearnerTestMixin):
10+
def setUp(self):
11+
self.widget = self.create_widget(OWAdaBoostRegression,
12+
stored_settings={"auto_apply": False})
13+
self.init()
14+
15+
def combo_set_value(i, x):
16+
x.activated.emit(i)
17+
x.setCurrentIndex(i)
18+
19+
losses = [loss.lower() for loss in self.widget.losses]
20+
nest_spin = self.widget.n_estimators_spin
21+
nest_min_max = [nest_spin.minimum(), nest_spin.maximum()]
22+
rate_spin = self.widget.learning_rate_spin
23+
rate_min_max = [rate_spin.minimum(), rate_spin.maximum()]
24+
self.gui_to_params = [
25+
GuiToParam('loss', self.widget.loss_combo,
26+
lambda x: x.currentText().lower(),
27+
combo_set_value, losses, list(range(len(losses)))),
28+
GuiToParam('learning_rate', rate_spin, lambda x: x.value(),
29+
lambda i, x: x.setValue(i), rate_min_max, rate_min_max),
30+
GuiToParam('n_estimators', nest_spin, lambda x: x.value(),
31+
lambda i, x: x.setValue(i), nest_min_max, nest_min_max)]
32+
33+
def test_input_learner(self):
34+
"""Check if base learner properly changes with learner on the input"""
35+
max_depth = 2
36+
default_base_est = self.widget.base_estimator
37+
self.assertIsInstance(default_base_est, TreeRegressionLearner)
38+
self.assertIsNone(default_base_est.params.get("max_depth"))
39+
self.send_signal("Learner", TreeRegressionLearner(max_depth=max_depth))
40+
self.assertEqual(self.widget.base_estimator.params.get("max_depth"),
41+
max_depth)
42+
self.widget.apply_button.button.click()
43+
output_base_est = self.get_output("Learner").params.get("base_estimator")
44+
self.assertEqual(output_base_est.max_depth, max_depth)
45+
46+
def test_input_learner_disconnect(self):
47+
"""Check base learner after disconnecting learner on the input"""
48+
self.send_signal("Learner", KNNRegressionLearner())
49+
self.assertIsInstance(self.widget.base_estimator, KNNRegressionLearner)
50+
self.send_signal("Learner", None)
51+
self.assertEqual(self.widget.base_estimator,
52+
self.widget.DEFAULT_BASE_ESTIMATOR)

Orange/widgets/tests/base.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import unittest
22
from collections import namedtuple
3-
from PyQt4 import QtGui
43

54
from PyQt4.QtGui import QApplication
65
import sip
@@ -259,18 +258,17 @@ def test_output_model(self):
259258
def test_output_learner_name(self):
260259
"""Check if learner's name properly changes"""
261260
new_name = "Learner Name"
262-
name_line_edit = self.widget.findChildren(QtGui.QLineEdit)[0]
263261
self.widget.apply_button.button.click()
264-
self.assertEqual(self.widget.learner.name, name_line_edit.text())
265-
name_line_edit.setText(new_name)
262+
self.assertEqual(self.widget.learner.name,
263+
self.widget.name_line_edit.text())
264+
self.widget.name_line_edit.setText(new_name)
266265
self.widget.apply_button.button.click()
267266
self.assertEqual(self.get_output("Learner").name, new_name)
268267

269268
def test_output_model_name(self):
270269
"""Check if model's name properly changes"""
271270
new_name = "Model Name"
272-
name_line_edit = self.widget.findChildren(QtGui.QLineEdit)[0]
273-
name_line_edit.setText(new_name)
271+
self.widget.name_line_edit.setText(new_name)
274272
self.send_signal("Data", self.data)
275273
self.widget.apply_button.button.click()
276274
self.assertEqual(self.get_output(self.model_name).name, new_name)
@@ -294,3 +292,11 @@ def test_parameters(self):
294292
param = self.widget.learner.params.get(element.name)
295293
self.assertEqual(param, element.get(element.gui_el))
296294
self.assertEqual(param, val)
295+
param = self.get_output("Learner").params.get(element.name)
296+
self.assertEqual(param, val)
297+
model = self.get_output(self.model_name)
298+
if model is not None:
299+
self.assertEqual(model.params.get(element.name), val)
300+
else:
301+
self.assertIn(self.widget.DATA_ERROR_ID,
302+
self.widget.widgetState.get("Error"))

Orange/widgets/utils/owlearnerwidget.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,10 @@ def add_main_layout(self):
225225
pass
226226

227227
def add_learner_name_widget(self):
228-
gui.lineEdit(self.controlArea, self, 'learner_name', box='Name',
229-
tooltip='The name will identify this model in other widgets',
230-
orientation=Qt.Horizontal,
231-
callback=lambda: self.apply())
228+
self.name_line_edit = gui.lineEdit(
229+
self.controlArea, self, 'learner_name', box='Name',
230+
tooltip='The name will identify this model in other widgets',
231+
orientation=Qt.Horizontal, callback=lambda: self.apply())
232232

233233
def add_bottom_buttons(self):
234234
box = gui.hBox(self.controlArea, True)

0 commit comments

Comments
 (0)