Skip to content

Commit 11c3f11

Browse files
authored
Merge pull request #1787 from pavlin-policar/improve-ada-boost-widget
[FIX] Improve ada boost widget
2 parents 751913a + fa6c972 commit 11c3f11

File tree

5 files changed

+91
-25
lines changed

5 files changed

+91
-25
lines changed

Orange/widgets/classify/owadaboost.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from AnyQt.QtCore import Qt
22

3+
from Orange.classification import SklTreeLearner
34
from Orange.classification.base_classification import LearnerClassification
45
from Orange.data import Table
5-
from Orange.classification import SklTreeLearner
66
from Orange.ensembles import SklAdaBoostLearner
77
from Orange.widgets import gui
88
from Orange.widgets.settings import Setting
99
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
10+
from Orange.widgets.widget import Msg
1011

1112

1213
class OWAdaBoostClassification(OWBaseLearner):
@@ -28,6 +29,9 @@ class OWAdaBoostClassification(OWBaseLearner):
2829

2930
DEFAULT_BASE_ESTIMATOR = SklTreeLearner()
3031

32+
class Error(OWBaseLearner.Error):
33+
no_weight_support = Msg('The base learner does not support weights.')
34+
3135
def add_main_layout(self):
3236
box = gui.widgetBox(self.controlArea, "Parameters")
3337
self.base_estimator = self.DEFAULT_BASE_ESTIMATOR
@@ -36,10 +40,11 @@ def add_main_layout(self):
3640

3741
self.n_estimators_spin = gui.spin(
3842
box, self, "n_estimators", 1, 100, label="Number of estimators:",
39-
alignment=Qt.AlignRight, callback=self.settings_changed)
43+
alignment=Qt.AlignRight, controlWidth=80,
44+
callback=self.settings_changed)
4045
self.learning_rate_spin = gui.doubleSpin(
4146
box, self, "learning_rate", 1e-5, 1.0, 1e-5, label="Learning rate:",
42-
decimals=5, alignment=Qt.AlignRight, controlWidth=90,
47+
decimals=5, alignment=Qt.AlignRight, controlWidth=80,
4348
callback=self.settings_changed)
4449
self.add_specific_parameters(box)
4550

@@ -49,18 +54,25 @@ def add_specific_parameters(self, box):
4954
orientation=Qt.Horizontal, callback=self.settings_changed)
5055

5156
def create_learner(self):
57+
if self.base_estimator is None:
58+
return None
5259
return self.LEARNER(
5360
base_estimator=self.base_estimator,
5461
n_estimators=self.n_estimators,
5562
learning_rate=self.learning_rate,
5663
preprocessors=self.preprocessors,
57-
algorithm=self.losses[self.algorithm]
58-
)
64+
algorithm=self.losses[self.algorithm])
5965

6066
def set_base_learner(self, learner):
61-
self.base_estimator = learner if learner \
62-
else self.DEFAULT_BASE_ESTIMATOR
63-
self.base_label.setText("Base estimator: " + self.base_estimator.name)
67+
self.Error.no_weight_support.clear()
68+
if learner and not learner.supports_weights:
69+
# Clear the error and reset to default base learner
70+
self.Error.no_weight_support()
71+
self.base_estimator = None
72+
self.base_label.setText("Base estimator: INVALID")
73+
else:
74+
self.base_estimator = learner or self.DEFAULT_BASE_ESTIMATOR
75+
self.base_label.setText("Base estimator: " + self.base_estimator.name)
6476
if self.auto_apply:
6577
self.apply()
6678

@@ -76,7 +88,7 @@ def get_learner_parameters(self):
7688

7789
a = QApplication(sys.argv)
7890
ow = OWAdaBoostClassification()
79-
ow.set_data(Table("iris"))
91+
ow.set_data(Table(sys.argv[1] if len(sys.argv) > 1 else 'iris'))
8092
ow.show()
8193
a.exec_()
8294
ow.saveSettings()

Orange/widgets/classify/tests/test_owadaboostclassification.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
3-
from Orange.classification import SklTreeLearner, KNNLearner
3+
from Orange.classification import (
4+
KNNLearner, RandomForestLearner, SklTreeLearner
5+
)
46
from Orange.widgets.classify.owadaboost import OWAdaBoostClassification
5-
from Orange.widgets.tests.base import (WidgetTest, WidgetLearnerTestMixin,
6-
ParameterMapping)
7+
from Orange.widgets.tests.base import (
8+
WidgetTest, WidgetLearnerTestMixin, ParameterMapping
9+
)
710

811

912
class TestOWAdaBoostClassification(WidgetTest, WidgetLearnerTestMixin):
@@ -30,10 +33,32 @@ def test_input_learner(self):
3033
output_base_est = self.get_output("Learner").params.get("base_estimator")
3134
self.assertEqual(output_base_est.max_depth, max_depth)
3235

36+
def test_input_learner_that_does_not_support_sample_weights(self):
37+
self.send_signal("Learner", KNNLearner())
38+
self.assertNotIsInstance(self.widget.base_estimator, KNNLearner)
39+
self.assertIsNone(self.widget.base_estimator)
40+
self.assertTrue(self.widget.Error.no_weight_support.is_shown())
41+
42+
def test_error_message_cleared_when_valid_learner_on_input(self):
43+
# Disconnecting an invalid learner should use the default one and hide
44+
# the error
45+
self.send_signal("Learner", KNNLearner())
46+
self.send_signal('Learner', None)
47+
self.assertFalse(
48+
self.widget.Error.no_weight_support.is_shown(),
49+
'Error message was not hidden on input disconnect')
50+
# Connecting a valid learner should also reset the error message
51+
self.send_signal("Learner", KNNLearner())
52+
self.send_signal('Learner', RandomForestLearner())
53+
self.assertFalse(
54+
self.widget.Error.no_weight_support.is_shown(),
55+
'Error message was not hidden when a valid learner appeared on '
56+
'input')
57+
3358
def test_input_learner_disconnect(self):
3459
"""Check base learner after disconnecting learner on the input"""
35-
self.send_signal("Learner", KNNLearner())
36-
self.assertIsInstance(self.widget.base_estimator, KNNLearner)
60+
self.send_signal("Learner", RandomForestLearner())
61+
self.assertIsInstance(self.widget.base_estimator, RandomForestLearner)
3762
self.send_signal("Learner", None)
3863
self.assertEqual(self.widget.base_estimator,
3964
self.widget.DEFAULT_BASE_ESTIMATOR)

Orange/widgets/regression/owadaboostregression.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from AnyQt.QtCore import Qt
22

3-
from Orange.regression.base_regression import LearnerRegression
4-
from Orange.regression import SklTreeRegressionLearner
53
from Orange.data import Table
64
from Orange.ensembles import SklAdaBoostRegressionLearner
5+
from Orange.regression import SklTreeRegressionLearner
6+
from Orange.regression.base_regression import LearnerRegression
77
from Orange.widgets import gui
88
from Orange.widgets.classify import owadaboost
99
from Orange.widgets.settings import Setting
@@ -36,8 +36,7 @@ def create_learner(self):
3636
n_estimators=self.n_estimators,
3737
learning_rate=self.learning_rate,
3838
preprocessors=self.preprocessors,
39-
loss=self.losses[self.loss].lower()
40-
)
39+
loss=self.losses[self.loss].lower())
4140

4241
def get_learner_parameters(self):
4342
return (("Base estimator", self.base_estimator),
@@ -51,7 +50,7 @@ def get_learner_parameters(self):
5150

5251
a = QApplication(sys.argv)
5352
ow = OWAdaBoostRegression()
54-
ow.set_data(Table("housing"))
53+
ow.set_data(Table(sys.argv[1] if len(sys.argv) > 1 else 'housing'))
5554
ow.show()
5655
a.exec_()
5756
ow.saveSettings()

Orange/widgets/regression/tests/test_owadaboostregression.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
3-
from Orange.regression import SklTreeRegressionLearner, KNNRegressionLearner
3+
from Orange.regression import (
4+
SklTreeRegressionLearner,
5+
KNNRegressionLearner,
6+
RandomForestRegressionLearner
7+
)
48
from Orange.widgets.regression.owadaboostregression import OWAdaBoostRegression
5-
from Orange.widgets.tests.base import (WidgetTest, WidgetLearnerTestMixin,
6-
ParameterMapping)
9+
from Orange.widgets.tests.base import (
10+
WidgetTest, WidgetLearnerTestMixin, ParameterMapping
11+
)
712

813

914
class TestOWAdaBoostRegression(WidgetTest, WidgetLearnerTestMixin):
@@ -31,10 +36,34 @@ def test_input_learner(self):
3136
output_base_est = self.get_output("Learner").params.get("base_estimator")
3237
self.assertEqual(output_base_est.max_depth, max_depth)
3338

39+
def test_input_learner_that_does_not_support_sample_weights(self):
40+
self.send_signal("Learner", KNNRegressionLearner())
41+
self.assertNotIsInstance(
42+
self.widget.base_estimator, KNNRegressionLearner)
43+
self.assertIsNone(self.widget.base_estimator)
44+
self.assertTrue(self.widget.Error.no_weight_support.is_shown())
45+
46+
def test_error_message_cleared_when_valid_learner_on_input(self):
47+
# Disconnecting an invalid learner should use the default one and hide
48+
# the error
49+
self.send_signal("Learner", KNNRegressionLearner())
50+
self.send_signal('Learner', None)
51+
self.assertFalse(
52+
self.widget.Error.no_weight_support.is_shown(),
53+
'Error message was not hidden on input disconnect')
54+
# Connecting a valid learner should also reset the error message
55+
self.send_signal("Learner", KNNRegressionLearner())
56+
self.send_signal('Learner', RandomForestRegressionLearner())
57+
self.assertFalse(
58+
self.widget.Error.no_weight_support.is_shown(),
59+
'Error message was not hidden when a valid learner appeared on '
60+
'input')
61+
3462
def test_input_learner_disconnect(self):
3563
"""Check base learner after disconnecting learner on the input"""
36-
self.send_signal("Learner", KNNRegressionLearner())
37-
self.assertIsInstance(self.widget.base_estimator, KNNRegressionLearner)
64+
self.send_signal("Learner", RandomForestRegressionLearner())
65+
self.assertIsInstance(
66+
self.widget.base_estimator, RandomForestRegressionLearner)
3867
self.send_signal("Learner", None)
3968
self.assertEqual(self.widget.base_estimator,
4069
self.widget.DEFAULT_BASE_ESTIMATOR)

Orange/widgets/utils/owlearnerwidget.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,8 @@ def apply(self):
165165

166166
def update_learner(self):
167167
self.learner = self.create_learner()
168-
self.learner.name = self.learner_name
168+
if self.learner is not None:
169+
self.learner.name = self.learner_name
169170
self.send("Learner", self.learner)
170171
self.outdated_settings = False
171172
self.Warning.outdated_learner.clear()

0 commit comments

Comments
 (0)