Skip to content

Commit eb9c1e0

Browse files
authored
Merge pull request #3172 from matejklemen/enh_roc_thresholds
[ENH] ROC analysis: show thresholds
2 parents a95245e + 099ee74 commit eb9c1e0

File tree

4 files changed

+144
-17
lines changed

4 files changed

+144
-17
lines changed

Orange/widgets/evaluate/owrocanalysis.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
import numpy
1111
import sklearn.metrics as skl_metrics
1212

13-
from AnyQt.QtWidgets import QListView, QLabel, QGridLayout, QFrame, QAction
14-
from AnyQt.QtGui import QColor, QPen, QBrush, QPainter, QPalette, QFont
13+
from AnyQt.QtWidgets import QListView, QLabel, QGridLayout, QFrame, QAction, QToolTip
14+
from AnyQt.QtGui import QColor, QPen, QBrush, QPainter, QPalette, QFont, QCursor
1515
from AnyQt.QtCore import Qt
1616
import pyqtgraph as pg
1717

@@ -336,6 +336,7 @@ def __init__(self):
336336
self._plot_curves = {}
337337
self._rocch = None
338338
self._perf_line = None
339+
self._tooltip_cache = None
339340

340341
box = gui.vBox(self.controlArea, "Plot")
341342
tbox = gui.vBox(box, "Target Class")
@@ -395,6 +396,7 @@ def __init__(self):
395396

396397
self.plotview = pg.GraphicsView(background="w")
397398
self.plotview.setFrameStyle(QFrame.StyledPanel)
399+
self.plotview.scene().sigMouseMoved.connect(self._on_mouse_moved)
398400

399401
self.plot = pg.PlotItem(enableMenu=False)
400402
self.plot.setMouseEnabled(False, False)
@@ -445,6 +447,7 @@ def clear(self):
445447
self._plot_curves = {}
446448
self._rocch = None
447449
self._perf_line = None
450+
self._tooltip_cache = None
448451

449452
def _initialize(self, results):
450453
names = getattr(results, "learner_names", None)
@@ -601,6 +604,68 @@ def _setup_plot(self):
601604
warning = "All ROC curves are undefined"
602605
self.warning(warning)
603606

607+
def _on_mouse_moved(self, pos):
608+
target = self.target_index
609+
selected = self.selected_classifiers
610+
curves = [(clf_idx, self.plot_curves(target, clf_idx))
611+
for clf_idx in selected] # type: List[Tuple[int, plot_curves]]
612+
valid_thresh, valid_clf = [], []
613+
pt, ave_mode = None, self.roc_averaging
614+
615+
for clf_idx, crv in curves:
616+
if self.roc_averaging == OWROCAnalysis.Merge:
617+
curve = crv.merge()
618+
elif self.roc_averaging == OWROCAnalysis.Vertical:
619+
curve = crv.avg_vertical()
620+
elif self.roc_averaging == OWROCAnalysis.Threshold:
621+
curve = crv.avg_threshold()
622+
else:
623+
# currently not implemented for 'Show Individual Curves'
624+
return
625+
626+
sp = curve.curve_item.childItems()[0] # type: pg.ScatterPlotItem
627+
act_pos = sp.mapFromScene(pos)
628+
pts = sp.pointsAt(act_pos)
629+
630+
if len(pts) > 0:
631+
mouse_pt = pts[0].pos()
632+
if self._tooltip_cache:
633+
cache_pt, cache_thresh, cache_clf, cache_ave = self._tooltip_cache
634+
curr_thresh, curr_clf = [], []
635+
if numpy.linalg.norm(mouse_pt - cache_pt) < 10e-6 \
636+
and cache_ave == self.roc_averaging:
637+
mask = numpy.equal(cache_clf, clf_idx)
638+
curr_thresh = numpy.compress(mask, cache_thresh).tolist()
639+
curr_clf = numpy.compress(mask, cache_clf).tolist()
640+
else:
641+
QToolTip.showText(QCursor.pos(), "")
642+
self._tooltip_cache = None
643+
644+
if curr_thresh:
645+
valid_thresh.append(*curr_thresh)
646+
valid_clf.append(*curr_clf)
647+
pt = cache_pt
648+
continue
649+
650+
curve_pts = curve.curve.points
651+
roc_points = numpy.column_stack((curve_pts.fpr, curve_pts.tpr))
652+
diff = numpy.subtract(roc_points, mouse_pt)
653+
# Find closest point on curve and save the corresponding threshold
654+
idx_closest = numpy.argmin(numpy.linalg.norm(diff, axis=1))
655+
656+
thresh = curve_pts.thresholds[idx_closest]
657+
if not numpy.isnan(thresh):
658+
valid_thresh.append(thresh)
659+
valid_clf.append(clf_idx)
660+
pt = [curve_pts.fpr[idx_closest], curve_pts.tpr[idx_closest]]
661+
662+
if valid_thresh:
663+
clf_names = self.classifier_names
664+
msg = "Thresholds:\n" + "\n".join(["({:s}) {:.3f}".format(clf_names[i], thresh)
665+
for i, thresh in zip(valid_clf, valid_thresh)])
666+
QToolTip.showText(QCursor.pos(), msg)
667+
self._tooltip_cache = (pt, valid_thresh, valid_clf, ave_mode)
668+
604669
def _on_target_changed(self):
605670
self.plot.clear()
606671
self._setup_plot()

Orange/widgets/evaluate/tests/test_owrocanalysis.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import copy
55
import numpy as np
66

7+
from AnyQt.QtWidgets import QToolTip
8+
79
import Orange.data
810
import Orange.evaluation
911
import Orange.classification
@@ -12,6 +14,7 @@
1214
from Orange.widgets.evaluate.owrocanalysis import OWROCAnalysis
1315
from Orange.widgets.evaluate.tests.base import EvaluateTest
1416
from Orange.widgets.tests.base import WidgetTest
17+
from Orange.widgets.tests.utils import mouseMove
1518

1619

1720
class TestROC(unittest.TestCase):
@@ -156,3 +159,60 @@ def test_nan_input(self):
156159
self.assertTrue(self.widget.Error.invalid_results.is_shown())
157160
self.send_signal(self.widget.Inputs.evaluation_results, None)
158161
self.assertFalse(self.widget.Error.invalid_results.is_shown())
162+
163+
def test_tooltips(self):
164+
data_in = Orange.data.Table("titanic")
165+
res = Orange.evaluation.TestOnTrainingData(
166+
data=data_in,
167+
learners=[Orange.classification.KNNLearner(),
168+
Orange.classification.LogisticRegressionLearner()],
169+
store_data=True
170+
)
171+
172+
self.send_signal(self.widget.Inputs.evaluation_results, res)
173+
self.widget.roc_averaging = OWROCAnalysis.Merge
174+
self.widget.target_index = 0
175+
self.widget.selected_classifiers = [0, 1]
176+
vb = self.widget.plot.getViewBox()
177+
vb.childTransform() # Force pyqtgraph to update transforms
178+
179+
curve = self.widget.plot_curves(self.widget.target_index, 0)
180+
curve_merge = curve.merge()
181+
view = self.widget.plotview
182+
item = curve_merge.curve_item # type: pg.PlotCurveItem
183+
184+
# no tooltips to be shown
185+
pos = item.mapToScene(0.0, 1.0)
186+
pos = view.mapFromScene(pos)
187+
mouseMove(view.viewport(), pos)
188+
self.assertIs(self.widget._tooltip_cache, None)
189+
190+
# test single point
191+
pos = item.mapToScene(0.22504, 0.45400)
192+
pos = view.mapFromScene(pos)
193+
mouseMove(view.viewport(), pos)
194+
shown_thresh = self.widget._tooltip_cache[1]
195+
self.assertTrue(QToolTip.isVisible())
196+
np.testing.assert_almost_equal(shown_thresh, [0.40000], decimal=5)
197+
198+
pos = item.mapToScene(0.0, 0.0)
199+
pos = view.mapFromScene(pos)
200+
# test overlapping points
201+
mouseMove(view.viewport(), pos)
202+
shown_thresh = self.widget._tooltip_cache[1]
203+
self.assertTrue(QToolTip.isVisible())
204+
np.testing.assert_almost_equal(shown_thresh, [1.8, 1.89336], decimal=5)
205+
206+
# test that cache is invalidated when changing averaging mode
207+
self.widget.roc_averaging = OWROCAnalysis.Threshold
208+
self.widget._replot()
209+
mouseMove(view.viewport(), pos)
210+
shown_thresh = self.widget._tooltip_cache[1]
211+
self.assertTrue(QToolTip.isVisible())
212+
np.testing.assert_almost_equal(shown_thresh, [1, 1])
213+
214+
# test nan thresholds
215+
self.widget.roc_averaging = OWROCAnalysis.Vertical
216+
self.widget._replot()
217+
mouseMove(view.viewport(), pos)
218+
self.assertIs(self.widget._tooltip_cache, None)

Orange/widgets/tests/utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import warnings
33
import contextlib
44

5-
from AnyQt.QtCore import Qt, QObject, QEventLoop, QTimer, QLocale
5+
from AnyQt.QtCore import Qt, QObject, QEventLoop, QTimer, QLocale, QPoint
66
from AnyQt.QtTest import QTest
7+
from AnyQt.QtGui import QMouseEvent
8+
from AnyQt.QtWidgets import QApplication
79

810

911
class EventSpy(QObject):
@@ -303,3 +305,15 @@ def wrap(*args, **kwargs):
303305
return result
304306
return wrap
305307
return wrapper
308+
309+
310+
def mouseMove(widget, pos=QPoint(), delay=-1): # pragma: no-cover
311+
# Like QTest.mouseMove, but functional without QCursor.setPos
312+
if pos.isNull():
313+
pos = widget.rect().center()
314+
me = QMouseEvent(QMouseEvent.MouseMove, pos, widget.mapToGlobal(pos),
315+
Qt.NoButton, Qt.MouseButtons(0), Qt.NoModifier)
316+
if delay > 0:
317+
QTest.qWait(delay)
318+
319+
QApplication.sendEvent(widget, me)

Orange/widgets/utils/tests/test_combobox.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import unittest
44

5-
from AnyQt.QtCore import Qt, QPoint, QRect
6-
from AnyQt.QtGui import QMouseEvent
5+
from AnyQt.QtCore import Qt, QRect
76
from AnyQt.QtWidgets import QListView, QApplication
87
from AnyQt.QtTest import QTest, QSignalSpy
98
from Orange.widgets.tests.base import GuiTest
9+
from Orange.widgets.tests.utils import mouseMove
1010

1111
from Orange.widgets.utils import combobox
1212

@@ -133,15 +133,3 @@ def test_popup_util(self):
133133
geom, QRect(0, 500, 100, 20), screen
134134
)
135135
self.assertEqual(g4, QRect(0, 500 - 400, 100, 400))
136-
137-
138-
def mouseMove(widget, pos=QPoint(), delay=-1): # pragma: no-cover
139-
# Like QTest.mouseMove, but functional without QCursor.setPos
140-
if pos.isNull():
141-
pos = widget.rect().center()
142-
me = QMouseEvent(QMouseEvent.MouseMove, pos, widget.mapToGlobal(pos),
143-
Qt.NoButton, Qt.MouseButtons(0), Qt.NoModifier)
144-
if delay > 0:
145-
QTest.qWait(delay)
146-
147-
QApplication.sendEvent(widget, me)

0 commit comments

Comments
 (0)