|
10 | 10 | import numpy |
11 | 11 | import sklearn.metrics as skl_metrics |
12 | 12 |
|
13 | | -from AnyQt.QtWidgets import QListView, QLabel, QGridLayout, QFrame, QAction |
14 | | -from AnyQt.QtGui import QColor, QPen, QBrush, QPainter, QPalette, QFont |
| 13 | +from AnyQt.QtWidgets import QListView, QLabel, QGridLayout, QFrame, QAction, QToolTip |
| 14 | +from AnyQt.QtGui import QColor, QPen, QBrush, QPainter, QPalette, QFont, QCursor |
15 | 15 | from AnyQt.QtCore import Qt |
16 | 16 | import pyqtgraph as pg |
17 | 17 |
|
@@ -336,6 +336,7 @@ def __init__(self): |
336 | 336 | self._plot_curves = {} |
337 | 337 | self._rocch = None |
338 | 338 | self._perf_line = None |
| 339 | + self._tooltip_cache = None |
339 | 340 |
|
340 | 341 | box = gui.vBox(self.controlArea, "Plot") |
341 | 342 | tbox = gui.vBox(box, "Target Class") |
@@ -395,6 +396,7 @@ def __init__(self): |
395 | 396 |
|
396 | 397 | self.plotview = pg.GraphicsView(background="w") |
397 | 398 | self.plotview.setFrameStyle(QFrame.StyledPanel) |
| 399 | + self.plotview.scene().sigMouseMoved.connect(self._on_mouse_moved) |
398 | 400 |
|
399 | 401 | self.plot = pg.PlotItem(enableMenu=False) |
400 | 402 | self.plot.setMouseEnabled(False, False) |
@@ -445,6 +447,7 @@ def clear(self): |
445 | 447 | self._plot_curves = {} |
446 | 448 | self._rocch = None |
447 | 449 | self._perf_line = None |
| 450 | + self._tooltip_cache = None |
448 | 451 |
|
449 | 452 | def _initialize(self, results): |
450 | 453 | names = getattr(results, "learner_names", None) |
@@ -601,6 +604,68 @@ def _setup_plot(self): |
601 | 604 | warning = "All ROC curves are undefined" |
602 | 605 | self.warning(warning) |
603 | 606 |
|
| 607 | + def _on_mouse_moved(self, pos): |
| 608 | + target = self.target_index |
| 609 | + selected = self.selected_classifiers |
| 610 | + curves = [(clf_idx, self.plot_curves(target, clf_idx)) |
| 611 | + for clf_idx in selected] # type: List[Tuple[int, plot_curves]] |
| 612 | + valid_thresh, valid_clf = [], [] |
| 613 | + pt, ave_mode = None, self.roc_averaging |
| 614 | + |
| 615 | + for clf_idx, crv in curves: |
| 616 | + if self.roc_averaging == OWROCAnalysis.Merge: |
| 617 | + curve = crv.merge() |
| 618 | + elif self.roc_averaging == OWROCAnalysis.Vertical: |
| 619 | + curve = crv.avg_vertical() |
| 620 | + elif self.roc_averaging == OWROCAnalysis.Threshold: |
| 621 | + curve = crv.avg_threshold() |
| 622 | + else: |
| 623 | + # currently not implemented for 'Show Individual Curves' |
| 624 | + return |
| 625 | + |
| 626 | + sp = curve.curve_item.childItems()[0] # type: pg.ScatterPlotItem |
| 627 | + act_pos = sp.mapFromScene(pos) |
| 628 | + pts = sp.pointsAt(act_pos) |
| 629 | + |
| 630 | + if len(pts) > 0: |
| 631 | + mouse_pt = pts[0].pos() |
| 632 | + if self._tooltip_cache: |
| 633 | + cache_pt, cache_thresh, cache_clf, cache_ave = self._tooltip_cache |
| 634 | + curr_thresh, curr_clf = [], [] |
| 635 | + if numpy.linalg.norm(mouse_pt - cache_pt) < 10e-6 \ |
| 636 | + and cache_ave == self.roc_averaging: |
| 637 | + mask = numpy.equal(cache_clf, clf_idx) |
| 638 | + curr_thresh = numpy.compress(mask, cache_thresh).tolist() |
| 639 | + curr_clf = numpy.compress(mask, cache_clf).tolist() |
| 640 | + else: |
| 641 | + QToolTip.showText(QCursor.pos(), "") |
| 642 | + self._tooltip_cache = None |
| 643 | + |
| 644 | + if curr_thresh: |
| 645 | + valid_thresh.append(*curr_thresh) |
| 646 | + valid_clf.append(*curr_clf) |
| 647 | + pt = cache_pt |
| 648 | + continue |
| 649 | + |
| 650 | + curve_pts = curve.curve.points |
| 651 | + roc_points = numpy.column_stack((curve_pts.fpr, curve_pts.tpr)) |
| 652 | + diff = numpy.subtract(roc_points, mouse_pt) |
| 653 | + # Find closest point on curve and save the corresponding threshold |
| 654 | + idx_closest = numpy.argmin(numpy.linalg.norm(diff, axis=1)) |
| 655 | + |
| 656 | + thresh = curve_pts.thresholds[idx_closest] |
| 657 | + if not numpy.isnan(thresh): |
| 658 | + valid_thresh.append(thresh) |
| 659 | + valid_clf.append(clf_idx) |
| 660 | + pt = [curve_pts.fpr[idx_closest], curve_pts.tpr[idx_closest]] |
| 661 | + |
| 662 | + if valid_thresh: |
| 663 | + clf_names = self.classifier_names |
| 664 | + msg = "Thresholds:\n" + "\n".join(["({:s}) {:.3f}".format(clf_names[i], thresh) |
| 665 | + for i, thresh in zip(valid_clf, valid_thresh)]) |
| 666 | + QToolTip.showText(QCursor.pos(), msg) |
| 667 | + self._tooltip_cache = (pt, valid_thresh, valid_clf, ave_mode) |
| 668 | + |
604 | 669 | def _on_target_changed(self): |
605 | 670 | self.plot.clear() |
606 | 671 | self._setup_plot() |
|
0 commit comments