Skip to content

Commit a4a0965

Browse files
committed
Permutation Plot: Add score info
1 parent 647eb31 commit a4a0965

File tree

2 files changed

+65
-5
lines changed

2 files changed

+65
-5
lines changed

Orange/widgets/evaluate/owpermutationplot.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import Optional, Tuple, Callable, List, Dict
1+
from typing import Optional, Tuple, Callable, List, Dict, Union
22

33
import numpy as np
44
from scipy.stats import spearmanr, linregress
55
from AnyQt.QtCore import Qt
6+
from AnyQt.QtWidgets import QLabel
67
import pyqtgraph as pg
78

89
from orangewidget.utils.visual_settings_dlg import VisualSettingsDialog, \
@@ -26,10 +27,20 @@
2627
from Orange.widgets.widget import OWWidget, Input, Msg
2728

2829
N_FOLD = 7
30+
# corr, scores_tr, intercept_tr, slope_tr,
31+
# scores_cv, intercept_cv, slope_cv, score_name
2932
PermutationResults = \
3033
Tuple[np.ndarray, List, float, float, List, float, float, str]
3134

3235

36+
def _f_lin(
37+
intercept: float,
38+
slope: float,
39+
x: Union[float, np.ndarray]
40+
) -> Union[float, np.ndarray]:
41+
return intercept + slope * x
42+
43+
3344
def _correlation(y: np.ndarray, y_pred: np.ndarray) -> float:
3445
return spearmanr(y, y_pred)[0] * 100
3546

@@ -192,8 +203,10 @@ def set_data(
192203

193204
x = np.array([0, 100])
194205
pen = pg.mkPen("#000", width=2, style=Qt.DashLine)
195-
line_tr = pg.PlotCurveItem(x, intercept_tr + slope_tr * x, pen=pen)
196-
line_cv = pg.PlotCurveItem(x, intercept_cv + slope_cv * x, pen=pen)
206+
y_tr = _f_lin(intercept_tr, slope_tr, x)
207+
y_cv = _f_lin(intercept_cv, slope_cv, x)
208+
line_tr = pg.PlotCurveItem(x, y_tr, pen=pen)
209+
line_cv = pg.PlotCurveItem(x, y_cv, pen=pen)
197210

198211
point_pen = pg.mkPen("#333")
199212
kwargs_tr = {"pen": point_pen, "symbol": "o", "brush": "#6fa255"}
@@ -230,7 +243,7 @@ class OWPermutationPlot(OWWidget, ConcurrentWidgetMixin):
230243

231244
class Inputs:
232245
data = Input("Data", Table)
233-
learner = Input("Lerner", Learner)
246+
learner = Input("Learner", Learner)
234247

235248
class Error(OWWidget.Error):
236249
domain_transform_err = Msg("{}")
@@ -243,6 +256,7 @@ def __init__(self):
243256
ConcurrentWidgetMixin.__init__(self)
244257
self._data: Optional[Table] = None
245258
self._learner: Optional[Learner] = None
259+
self._info: QLabel = None
246260
self.graph: PermutationPlot = None
247261
self.setup_gui()
248262
VisualSettingsDialog(
@@ -264,6 +278,37 @@ def _add_controls(self):
264278
minv=1, maxv=1000, callback=self._run)
265279
gui.rubber(self.controlArea)
266280

281+
box = gui.vBox(self.controlArea, "Info")
282+
self._info = gui.label(box, self, "", textFormat=Qt.RichText,
283+
minimumWidth=180)
284+
self.__set_info(None)
285+
286+
def __set_info(self, result: PermutationResults):
287+
html = "No data available."
288+
if result is not None:
289+
intercept_tr, slope_tr, _, intercept_cv, slope_cv = result[2: -1]
290+
html = """
291+
<table width=100% align="center" style="font-size:11px">
292+
<tr style="background:#fefefe">
293+
<th style="background:transparent;padding: 2px 4px"></th>
294+
<th style="background:transparent;padding: 2px 4px">Corr = 0</th>
295+
<th style="background:transparent;padding: 2px 4px">Corr = 100</th>
296+
</tr>
297+
<tr style="background:#fefefe">
298+
<th style="padding: 2px 4px" align=right>Train</th>
299+
<td style="padding: 2px 4px" align=right>{:.4f}</td>
300+
<td style="padding: 2px 4px" align=right>{:.4f}</td>
301+
</tr>
302+
<tr style="background:#fefefe">
303+
<th style="padding: 2px 4px" align=right>CV</th>
304+
<td style="padding: 2px 4px" align=right>{:.4f}</td>
305+
<td style="padding: 2px 4px" align=right>{:.4f}</td>
306+
</tr>
307+
</table>
308+
""".format(intercept_tr, _f_lin(intercept_tr, slope_tr, 100),
309+
intercept_cv, _f_lin(intercept_cv, slope_cv, 100))
310+
self._info.setText(html)
311+
267312
@Inputs.data
268313
@check_multiple_targets_input
269314
def set_data(self, data: Table):
@@ -296,6 +341,7 @@ def clear(self):
296341
self.cancel()
297342
self.graph.clear()
298343
self.graph.setTitle()
344+
self.__set_info(None)
299345

300346
def _run(self):
301347
if self._data is None or self._learner is None:
@@ -304,6 +350,7 @@ def _run(self):
304350

305351
def on_done(self, result: PermutationResults):
306352
self.graph.set_data(*result)
353+
self.__set_info(result)
307354

308355
def on_exception(self, ex: Exception):
309356
if isinstance(ex, DomainTransformationError):

Orange/widgets/evaluate/tests/test_owpermutationplot.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# pylint: disable=missing-docstring
1+
# pylint: disable=missing-docstring,protected-access
22
import unittest
33

44
from Orange.classification import RandomForestLearner, \
@@ -87,6 +87,19 @@ def test_sample_data(self):
8787
self.assertFalse(self.widget.Error.not_enough_data.is_shown())
8888
self.wait_until_finished()
8989

90+
def test_info(self):
91+
self.send_signal(self.widget.Inputs.learner, self.rf_cls)
92+
self.send_signal(self.widget.Inputs.data, self.heart)
93+
self.wait_until_finished()
94+
self.assertIn("0.5021", self.widget._info.text())
95+
self.assertIn('<th style="padding: 2px 4px" align=right>CV</th>',
96+
self.widget._info.text())
97+
self.assertIn('<th style="padding: 2px 4px" align=right>Train</th>',
98+
self.widget._info.text())
99+
100+
self.send_signal(self.widget.Inputs.learner, None)
101+
self.assertEqual(self.widget._info.text(), "No data available.")
102+
90103
def test_send_report(self):
91104
self.widget.send_report()
92105
self.send_signal(self.widget.Inputs.data, self.heart[:10])

0 commit comments

Comments
 (0)