Skip to content

Commit a9b395a

Browse files
committed
TestOWAdaBoostRegression: Add tests
1 parent 30f7e83 commit a9b395a

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Test methods with long descriptive names can omit docstrings
2+
# pylint: disable=missing-docstring
3+
from Orange.regression import TreeRegressionLearner, KNNRegressionLearner
4+
from Orange.widgets.regression.owadaboostregression import OWAdaBoostRegression
5+
from Orange.widgets.tests.base import (WidgetTest, WidgetLearnerTestMixin,
6+
GuiToParam)
7+
8+
9+
class TestOWAdaBoostRegression(WidgetTest, WidgetLearnerTestMixin):
10+
def setUp(self):
11+
self.widget = self.create_widget(OWAdaBoostRegression,
12+
stored_settings={"auto_apply": False})
13+
self.init()
14+
15+
def combo_set_value(i, x):
16+
x.activated.emit(i)
17+
x.setCurrentIndex(i)
18+
19+
losses = [loss.lower() for loss in self.widget.losses]
20+
nest_spin = self.widget.n_estimators_spin
21+
nest_min_max = [nest_spin.minimum(), nest_spin.maximum()]
22+
rate_spin = self.widget.learning_rate_spin
23+
rate_min_max = [rate_spin.minimum(), rate_spin.maximum()]
24+
self.gui_to_params = [
25+
GuiToParam('loss', self.widget.loss_combo,
26+
lambda x: x.currentText().lower(),
27+
combo_set_value, losses, list(range(len(losses)))),
28+
GuiToParam('learning_rate', rate_spin, lambda x: x.value(),
29+
lambda i, x: x.setValue(i), rate_min_max, rate_min_max),
30+
GuiToParam('n_estimators', nest_spin, lambda x: x.value(),
31+
lambda i, x: x.setValue(i), nest_min_max, nest_min_max)]
32+
33+
def test_input_learner(self):
34+
"""Check if base learner properly changes with learner on the input"""
35+
max_depth = 2
36+
default_base_est = self.widget.base_estimator
37+
self.assertIsInstance(default_base_est, TreeRegressionLearner)
38+
self.assertIsNone(default_base_est.params.get("max_depth"))
39+
self.send_signal("Learner", TreeRegressionLearner(max_depth=max_depth))
40+
self.assertEqual(self.widget.base_estimator.params.get("max_depth"),
41+
max_depth)
42+
self.widget.apply_button.button.click()
43+
output_base_est = self.get_output("Learner").params.get("base_estimator")
44+
self.assertEqual(output_base_est.max_depth, max_depth)
45+
46+
def test_input_learner_disconnect(self):
47+
"""Check base learner after disconnecting learner on the input"""
48+
self.send_signal("Learner", KNNRegressionLearner())
49+
self.assertIsInstance(self.widget.base_estimator, KNNRegressionLearner)
50+
self.send_signal("Learner", None)
51+
self.assertEqual(self.widget.base_estimator,
52+
self.widget.DEFAULT_BASE_ESTIMATOR)

0 commit comments

Comments
 (0)