Skip to content

Commit 566bc51

Browse files
committed
Parameter Fitter: Tests
1 parent 42c9de4 commit 566bc51

File tree

2 files changed

+190
-19
lines changed

2 files changed

+190
-19
lines changed

Orange/widgets/evaluate/owparameterfitter.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -126,16 +126,11 @@ def update_setters(self):
126126
def update_grid(**settings):
127127
self.grid_settings.update(**settings)
128128
self.master.showGrid(
129-
x=self.grid_settings[self.SHOW_GRID_LABEL],
130-
y=self.grid_settings[self.SHOW_GRID_LABEL],
129+
x=False, y=self.grid_settings[self.SHOW_GRID_LABEL],
131130
alpha=self.grid_settings[Updater.ALPHA_LABEL] / 255)
132131

133132
self._setters[self.PLOT_BOX] = {self.GRID_LABEL: update_grid}
134133

135-
@property
136-
def title_item(self):
137-
return self.master.getPlotItem().titleLabel
138-
139134
@property
140135
def axis_items(self):
141136
return [value["item"] for value in
@@ -159,8 +154,7 @@ def __init__(self):
159154
self.setMouseEnabled(False, False)
160155
self.hideButtons()
161156

162-
self.showGrid(False, True)
163-
self.showGrid(y=self.parameter_setter.DEFAULT_SHOW_GRID,
157+
self.showGrid(x=False, y=self.parameter_setter.DEFAULT_SHOW_GRID,
164158
alpha=self.parameter_setter.DEFAULT_ALPHA_GRID / 255)
165159

166160
self.tooltip_delegate = HelpEventDelegate(self.help_event)
@@ -284,7 +278,6 @@ class Inputs:
284278
auto_commit = Setting(True)
285279

286280
class Error(OWWidget.Error):
287-
domain_transform_err = Msg("{}")
288281
unknown_err = Msg("{}")
289282
not_enough_data = Msg(f"At least {N_FOLD} instances are needed.")
290283
incompatible_learner = Msg("{}")
@@ -409,7 +402,6 @@ def handleNewSignals(self):
409402
self.Warning.no_parameters.clear()
410403
self.Error.incompatible_learner.clear()
411404
self.Error.unknown_err.clear()
412-
self.Error.domain_transform_err.clear()
413405
self.clear()
414406
if self._data is None or self._learner is None:
415407
return
@@ -454,8 +446,8 @@ def _set_range_controls(self):
454446
self.__spin_max.setMinimum(-MIN_MAX_SPIN)
455447
self.minimum = self.initial_parameters[param.parameter_name]
456448
if param.max is not None:
457-
self.__spin_min.setMaximum(param.min)
458-
self.__spin_max.setMaximum(param.min)
449+
self.__spin_min.setMaximum(param.max)
450+
self.__spin_max.setMaximum(param.max)
459451
self.maximum = param.max
460452
else:
461453
self.__spin_min.setMaximum(MIN_MAX_SPIN)
@@ -484,10 +476,7 @@ def on_done(self, result: FitterResults):
484476
self.graph.set_data(*result)
485477

486478
def on_exception(self, ex: Exception):
487-
if isinstance(ex, DomainTransformationError):
488-
self.Error.domain_transform_err(ex)
489-
else:
490-
self.Error.unknown_err(ex)
479+
self.Error.unknown_err(ex)
491480

492481
def on_partial_result(self, _):
493482
pass

Orange/widgets/evaluate/tests/test_owparameterfitter.py

Lines changed: 185 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,31 @@
11
# pylint: disable=missing-docstring,protected-access
22
import unittest
3+
from unittest.mock import patch, Mock
4+
5+
import pyqtgraph as pg
6+
7+
from AnyQt.QtCore import QPointF
8+
from AnyQt.QtGui import QFont
9+
from AnyQt.QtWidgets import QToolTip
310

411
from Orange.classification import NaiveBayesLearner
512
from Orange.data import Table
13+
from Orange.modelling import RandomForestLearner
614
from Orange.regression import PLSRegressionLearner
715
from Orange.widgets.evaluate.owparameterfitter import OWParameterFitter
816
from Orange.widgets.model.owrandomforest import OWRandomForest
917
from Orange.widgets.tests.base import WidgetTest
18+
from Orange.widgets.tests.utils import simulate
19+
20+
21+
class DummyLearner(PLSRegressionLearner):
22+
def fitted_parameters(self):
23+
return [
24+
self.FittedParameter("n_components", "Foo", "foo", int, 1, None),
25+
self.FittedParameter("n_components", "Bar", "bar", int, 1, 10),
26+
self.FittedParameter("n_components", "Baz", "baz", int, None, 10),
27+
self.FittedParameter("n_components", "Qux", "qux", int, None, None)
28+
]
1029

1130

1231
class TestOWParameterFitter(WidgetTest):
@@ -17,6 +36,8 @@ def setUpClass(cls):
1736
cls._housing = Table("housing")
1837
cls._naive_bayes = NaiveBayesLearner()
1938
cls._pls = PLSRegressionLearner()
39+
cls._rf = RandomForestLearner()
40+
cls._dummy = DummyLearner()
2041

2142
def setUp(self):
2243
self.widget = self.create_widget(OWParameterFitter)
@@ -48,21 +69,18 @@ def test_random_forest(self):
4869

4970
self.send_signal(self.widget.Inputs.learner, learner)
5071
self.assertFalse(self.widget.Warning.no_parameters.is_shown())
51-
self.assertFalse(self.widget.Error.domain_transform_err.is_shown())
5272
self.assertFalse(self.widget.Error.unknown_err.is_shown())
5373
self.assertFalse(self.widget.Error.not_enough_data.is_shown())
5474
self.assertFalse(self.widget.Error.incompatible_learner.is_shown())
5575

5676
self.send_signal(self.widget.Inputs.data, self._heart)
5777
self.assertFalse(self.widget.Warning.no_parameters.is_shown())
58-
self.assertFalse(self.widget.Error.domain_transform_err.is_shown())
5978
self.assertFalse(self.widget.Error.unknown_err.is_shown())
6079
self.assertFalse(self.widget.Error.not_enough_data.is_shown())
6180
self.assertFalse(self.widget.Error.incompatible_learner.is_shown())
6281

6382
self.send_signal(self.widget.Inputs.data, self._housing)
6483
self.assertFalse(self.widget.Warning.no_parameters.is_shown())
65-
self.assertFalse(self.widget.Error.domain_transform_err.is_shown())
6684
self.assertFalse(self.widget.Error.unknown_err.is_shown())
6785
self.assertFalse(self.widget.Error.not_enough_data.is_shown())
6886
self.assertFalse(self.widget.Error.incompatible_learner.is_shown())
@@ -77,6 +95,31 @@ def test_plot(self):
7795
x = self.widget.graph._FitterPlot__bar_item_cv.opts["x"]
7896
self.assertEqual(list(x), [0.2, 1.2])
7997

98+
@patch.object(QToolTip, "showText")
99+
def test_tooltip(self, show_text):
100+
graph = self.widget.graph
101+
102+
self.assertFalse(self.widget.graph.help_event(Mock()))
103+
self.assertIsNone(show_text.call_args)
104+
105+
self.send_signal(self.widget.Inputs.data, self._housing)
106+
self.send_signal(self.widget.Inputs.learner, self._pls)
107+
self.wait_until_finished()
108+
109+
for item in graph.items():
110+
if isinstance(item, pg.BarGraphItem):
111+
item.mapFromScene = Mock(return_value=QPointF(0.2, 0.2))
112+
113+
self.assertTrue(self.widget.graph.help_event(Mock()))
114+
self.assertIn("Train:", show_text.call_args[0][1])
115+
self.assertIn("CV:", show_text.call_args[0][1])
116+
117+
for item in graph.items():
118+
if isinstance(item, pg.BarGraphItem):
119+
item.mapFromScene = Mock(return_value=QPointF(0.5, 0.5))
120+
self.assertFalse(self.widget.graph.help_event(Mock()))
121+
122+
80123
def test_manual_steps(self):
81124
self.send_signal(self.widget.Inputs.data, self._housing)
82125
self.send_signal(self.widget.Inputs.learner, self._pls)
@@ -106,6 +149,145 @@ def test_steps_preview(self):
106149
self.wait_until_finished()
107150
self.assertEqual(self.widget.preview, "[10, 15, 20, 25]")
108151

152+
def test_on_parameter_changed(self):
153+
self.send_signal(self.widget.Inputs.data, self._housing)
154+
self.send_signal(self.widget.Inputs.learner, self._dummy)
155+
self.wait_until_finished()
156+
157+
self.widget.commit.deferred = Mock()
158+
159+
for i in range(1, 4):
160+
self.widget.commit.deferred.reset_mock()
161+
simulate.combobox_activate_index(
162+
self.widget.controls.parameter_index, i)
163+
self.wait_until_finished()
164+
self.widget.commit.deferred.assert_called_once()
165+
166+
def test_not_enough_data(self):
167+
self.send_signal(self.widget.Inputs.data, self._housing[:5])
168+
self.send_signal(self.widget.Inputs.learner, self._pls)
169+
self.wait_until_finished()
170+
self.assertTrue(self.widget.Error.not_enough_data.is_shown())
171+
self.send_signal(self.widget.Inputs.data, None)
172+
self.assertFalse(self.widget.Error.not_enough_data.is_shown())
173+
174+
def test_unknown_err(self):
175+
self.send_signal(self.widget.Inputs.data, Table("iris")[:50])
176+
self.send_signal(self.widget.Inputs.learner, self._rf)
177+
self.wait_until_finished()
178+
self.assertTrue(self.widget.Error.unknown_err.is_shown())
179+
self.send_signal(self.widget.Inputs.data, None)
180+
self.assertFalse(self.widget.Error.unknown_err.is_shown())
181+
182+
def test_fitted_parameters(self):
183+
self.assertEqual(self.widget.fitted_parameters, [])
184+
185+
self.send_signal(self.widget.Inputs.data, self._housing)
186+
self.assertEqual(self.widget.fitted_parameters, [])
187+
188+
self.send_signal(self.widget.Inputs.learner, self._pls)
189+
self.assertEqual(len(self.widget.fitted_parameters), 1)
190+
self.wait_until_finished()
191+
192+
self.send_signal(self.widget.Inputs.data, None)
193+
self.assertEqual(self.widget.fitted_parameters, [])
194+
195+
def test_initial_parameters(self):
196+
self.assertEqual(self.widget.initial_parameters, {})
197+
198+
self.send_signal(self.widget.Inputs.data, self._housing)
199+
self.assertEqual(self.widget.initial_parameters, {})
200+
201+
self.send_signal(self.widget.Inputs.learner, self._pls)
202+
self.assertEqual(len(self.widget.initial_parameters), 3)
203+
self.wait_until_finished()
204+
205+
self.send_signal(self.widget.Inputs.learner, self._rf)
206+
self.assertEqual(len(self.widget.initial_parameters), 13)
207+
self.wait_until_finished()
208+
209+
self.send_signal(self.widget.Inputs.data, None)
210+
self.assertEqual(self.widget.initial_parameters, {})
211+
212+
def test_saved_workflow(self):
213+
self.send_signal(self.widget.Inputs.data, self._housing)
214+
self.send_signal(self.widget.Inputs.learner, self._dummy)
215+
self.wait_until_finished()
216+
simulate.combobox_activate_index(
217+
self.widget.controls.parameter_index, 2)
218+
self.widget.controls.minimum.setValue(3)
219+
self.widget.controls.maximum.setValue(6)
220+
self.wait_until_finished()
221+
222+
settings = self.widget.settingsHandler.pack_data(self.widget)
223+
widget = self.create_widget(OWParameterFitter,
224+
stored_settings=settings)
225+
self.send_signal(widget.Inputs.data, self._housing, widget=widget)
226+
self.send_signal(widget.Inputs.learner, self._dummy, widget=widget)
227+
self.wait_until_finished(widget=widget)
228+
self.assertEqual(widget.controls.parameter_index.currentText(), "Baz")
229+
self.assertEqual(widget.minimum, 3)
230+
self.assertEqual(widget.maximum, 6)
231+
232+
def test_visual_settings(self):
233+
graph = self.widget.graph
234+
235+
def test_settings():
236+
font = QFont("Helvetica", italic=True, pointSize=20)
237+
for item in graph.parameter_setter.axis_items:
238+
self.assertFontEqual(item.label.font(), font)
239+
font.setPointSize(15)
240+
for item in graph.parameter_setter.axis_items:
241+
self.assertFontEqual(item.style["tickFont"], font)
242+
font.setPointSize(17)
243+
for legend_item in graph.parameter_setter.legend_items:
244+
self.assertFontEqual(legend_item[1].item.font(), font)
245+
self.assertFalse(graph.getAxis("left").grid)
246+
247+
key, value = ("Fonts", "Font family", "Font family"), "Helvetica"
248+
self.widget.set_visual_settings(key, value)
249+
250+
key, value = ("Fonts", "Axis title", "Font size"), 20
251+
self.widget.set_visual_settings(key, value)
252+
key, value = ("Fonts", "Axis title", "Italic"), True
253+
self.widget.set_visual_settings(key, value)
254+
255+
key, value = ("Fonts", "Axis ticks", "Font size"), 15
256+
self.widget.set_visual_settings(key, value)
257+
key, value = ("Fonts", "Axis ticks", "Italic"), True
258+
self.widget.set_visual_settings(key, value)
259+
260+
key, value = ("Fonts", "Legend", "Font size"), 17
261+
self.widget.set_visual_settings(key, value)
262+
key, value = ("Fonts", "Legend", "Italic"), True
263+
self.widget.set_visual_settings(key, value)
264+
265+
key, value = ("Figure", "Gridlines", "Show"), False
266+
self.widget.set_visual_settings(key, value)
267+
key, value = ("Figure", "Gridlines", "Opacity"), 20
268+
self.widget.set_visual_settings(key, value)
269+
270+
test_settings()
271+
272+
self.send_signal(self.widget.Inputs.learner, self._pls)
273+
self.send_signal(self.widget.Inputs.data, self._heart[:10])
274+
test_settings()
275+
276+
self.send_signal(self.widget.Inputs.data, None)
277+
self.send_signal(self.widget.Inputs.learner, None)
278+
279+
self.send_signal(self.widget.Inputs.learner, self._pls)
280+
self.send_signal(self.widget.Inputs.data, self._heart[:10])
281+
test_settings()
282+
283+
def assertFontEqual(self, font1: QFont, font2: QFont):
284+
self.assertEqual(font1.family(), font2.family())
285+
self.assertEqual(font1.pointSize(), font2.pointSize())
286+
self.assertEqual(font1.italic(), font2.italic())
287+
288+
def test_send_report(self):
289+
self.assertEqual(1, 2)
290+
109291

110292
if __name__ == "__main__":
111293
unittest.main()

0 commit comments

Comments
 (0)