Skip to content

Commit 30f7e83

Browse files
committed
TestOWAdaBoostClassification: Modify tests
1 parent 24abe42 commit 30f7e83

File tree

1 file changed

+46
-41
lines changed

1 file changed

+46
-41
lines changed
Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,52 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
3-
from PyQt4 import QtGui
4-
5-
from Orange.widgets.tests.base import WidgetTest
3+
from Orange.classification import TreeLearner, KNNLearner
64
from Orange.widgets.classify.owadaboost import OWAdaBoostClassification
5+
from Orange.widgets.tests.base import (WidgetTest, WidgetLearnerTestMixin,
6+
GuiToParam)
77

88

9-
10-
class TestOWAdaBoostClassification(WidgetTest):
11-
9+
class TestOWAdaBoostClassification(WidgetTest, WidgetLearnerTestMixin):
1210
def setUp(self):
13-
self.widget = self.create_widget(OWAdaBoostClassification)
14-
self.spinners = []
15-
self.spinners.append(self.widget.findChildren(QtGui.QSpinBox)[0])
16-
self.spinners.append(self.widget.findChildren(QtGui.QDoubleSpinBox)[0])
17-
self.combobox_algorithm = self.widget.findChildren(QtGui.QComboBox)[0]
18-
19-
def test_visible_boxes(self):
20-
""" Check if boxes are visible """
21-
self.assertEqual(self.spinners[0].isHidden(), False)
22-
self.assertEqual(self.spinners[1].isHidden(), False)
23-
self.assertEqual(self.combobox_algorithm.isHidden(), False)
24-
25-
def test_parameters_on_output(self):
26-
""" Check right paramaters on output """
27-
self.widget.apply()
28-
learner_params = self.widget.learner.params
29-
self.assertEqual(learner_params.get("n_estimators"), self.spinners[0].value())
30-
self.assertEqual(learner_params.get("learning_rate"), self.spinners[1].value())
31-
self.assertEqual(learner_params.get('algorithm'), self.combobox_algorithm.currentText())
32-
33-
34-
def test_output_algorithm(self):
35-
""" Check if right learning algorithm is on output when we change algorithm """
36-
for index, algorithmName in enumerate(self.widget.losses):
37-
self.combobox_algorithm.setCurrentIndex(index)
38-
self.combobox_algorithm.activated.emit(index)
39-
self.assertEqual(self.combobox_algorithm.currentText(), algorithmName)
40-
self.widget.apply()
41-
self.assertEqual(self.widget.learner.params.get("algorithm").capitalize(),
42-
self.combobox_algorithm.currentText().capitalize())
43-
44-
def test_learner_on_output(self):
45-
""" Check if learner is on output after create widget and apply """
46-
self.widget.apply()
47-
self.assertNotEqual(self.widget.learner, None)
11+
self.widget = self.create_widget(OWAdaBoostClassification,
12+
stored_settings={"auto_apply": False})
13+
self.init()
14+
15+
def combo_set_value(i, x):
16+
x.activated.emit(i)
17+
x.setCurrentIndex(i)
18+
19+
losses = self.widget.losses
20+
nest_spin = self.widget.n_estimators_spin
21+
nest_min_max = [nest_spin.minimum(), nest_spin.maximum()]
22+
rate_spin = self.widget.learning_rate_spin
23+
rate_min_max = [rate_spin.minimum(), rate_spin.maximum()]
24+
self.gui_to_params = [
25+
GuiToParam('algorithm', self.widget.algorithm_combo,
26+
lambda x: x.currentText(),
27+
combo_set_value, losses, list(range(len(losses)))),
28+
GuiToParam('learning_rate', rate_spin, lambda x: x.value(),
29+
lambda i, x: x.setValue(i), rate_min_max, rate_min_max),
30+
GuiToParam('n_estimators', nest_spin, lambda x: x.value(),
31+
lambda i, x: x.setValue(i), nest_min_max, nest_min_max)]
32+
33+
def test_input_learner(self):
34+
"""Check if base learner properly changes with learner on the input"""
35+
max_depth = 2
36+
default_base_est = self.widget.base_estimator
37+
self.assertIsInstance(default_base_est, TreeLearner)
38+
self.assertIsNone(default_base_est.params.get("max_depth"))
39+
self.send_signal("Learner", TreeLearner(max_depth=max_depth))
40+
self.assertEqual(self.widget.base_estimator.params.get("max_depth"),
41+
max_depth)
42+
self.widget.apply_button.button.click()
43+
output_base_est = self.get_output("Learner").params.get("base_estimator")
44+
self.assertEqual(output_base_est.max_depth, max_depth)
45+
46+
def test_input_learner_disconnect(self):
47+
"""Check base learner after disconnecting learner on the input"""
48+
self.send_signal("Learner", KNNLearner())
49+
self.assertIsInstance(self.widget.base_estimator, KNNLearner)
50+
self.send_signal("Learner", None)
51+
self.assertEqual(self.widget.base_estimator,
52+
self.widget.DEFAULT_BASE_ESTIMATOR)

0 commit comments

Comments
 (0)