Skip to content

Commit 9c593cc

Browse files
authored
Merge pull request #5138 from janezd/roc-point-fixes
[FIX] ROC shows all points, including the last
2 parents 0a83c6c + 1eca4d6 commit 9c593cc

File tree

2 files changed

+53
-13
lines changed

2 files changed

+53
-13
lines changed

Orange/widgets/evaluate/owrocanalysis.py

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
"""
2-
ROC Analysis Widget
3-
-------------------
4-
5-
"""
61
import operator
72
from functools import reduce, wraps
83
from collections import namedtuple, deque, OrderedDict
@@ -11,7 +6,7 @@
116
import sklearn.metrics as skl_metrics
127

138
from AnyQt.QtWidgets import QListView, QLabel, QGridLayout, QFrame, QAction, \
14-
QToolTip, QSizePolicy
9+
QToolTip
1510
from AnyQt.QtGui import QColor, QPen, QBrush, QPainter, QPalette, QFont, \
1611
QCursor, QFontMetrics
1712
from AnyQt.QtCore import Qt, QSize
@@ -21,13 +16,15 @@
2116
from Orange.widgets import widget, gui, settings
2217
from Orange.widgets.evaluate.contexthandlers import \
2318
EvaluationResultsContextHandler
24-
from Orange.widgets.evaluate.utils import \
25-
check_results_adequacy, results_for_preview
19+
from Orange.widgets.evaluate.utils import check_results_adequacy
2620
from Orange.widgets.utils import colorpalettes
2721
from Orange.widgets.utils.widgetpreview import WidgetPreview
2822
from Orange.widgets.widget import Input
2923
from Orange.widgets import report
3024

25+
from Orange.widgets.evaluate.utils import results_for_preview
26+
from Orange.evaluation.testing import Results
27+
3128

3229
#: Points on a ROC curve
3330
ROCPoints = namedtuple(
@@ -93,11 +90,11 @@ def roc_data_from_results(results, clf_index, target):
9390
:rval ROCData:
9491
A instance holding the computed curves.
9592
"""
96-
merged = roc_curve_for_fold(results, slice(0, -1), clf_index, target)
93+
merged = roc_curve_for_fold(results, ..., clf_index, target)
9794
merged_curve = ROCCurve(ROCPoints(*merged),
9895
ROCPoints(*roc_curve_convex_hull(merged)))
9996

100-
folds = results.folds if results.folds is not None else [slice(0, -1)]
97+
folds = results.folds if results.folds is not None else [...]
10198
fold_curves = []
10299
for fold in folds:
103100
points = roc_curve_for_fold(results, fold, clf_index, target)
@@ -413,11 +410,13 @@ def __init__(self):
413410
axis.setTickFont(tickfont)
414411
axis.setPen(pen)
415412
axis.setLabel("FP Rate (1-Specificity)")
413+
axis.setGrid(16)
416414

417415
axis = self.plot.getAxis("left")
418416
axis.setTickFont(tickfont)
419417
axis.setPen(pen)
420418
axis.setLabel("TP Rate (Sensitivity)")
419+
axis.setGrid(16)
421420

422421
self.plot.showGrid(True, True, alpha=0.1)
423422
self.plot.setRange(xRange=(0.0, 1.0), yRange=(0.0, 1.0), padding=0.05)
@@ -621,6 +620,8 @@ def no_averaging():
621620
if self.roc_averaging == OWROCAnalysis.Merge:
622621
self._update_perf_line()
623622

623+
self._update_axes_ticks()
624+
624625
warning = ""
625626
if not all(c.is_valid for c in hull_curves):
626627
if any(c.is_valid for c in hull_curves):
@@ -629,6 +630,22 @@ def no_averaging():
629630
warning = "All ROC curves are undefined"
630631
self.warning(warning)
631632

633+
def _update_axes_ticks(self):
634+
def enumticks(a):
635+
a = np.unique(a)
636+
if len(a) > 15:
637+
return None
638+
return [[(x, f"{x:.2f}") for x in a[::-1]]]
639+
640+
data = self.curve_data(self.target_index, self.selected_classifiers[0])
641+
points = data.merged.points
642+
643+
axis = self.plot.getAxis("bottom")
644+
axis.setTicks(enumticks(points.fpr))
645+
646+
axis = self.plot.getAxis("left")
647+
axis.setTicks(enumticks(points.tpr))
648+
632649
def _on_mouse_moved(self, pos):
633650
target = self.target_index
634651
selected = self.selected_classifiers
@@ -802,10 +819,19 @@ def roc_curve_for_fold(res, fold, clf_idx, target):
802819
return np.array([]), np.array([]), np.array([])
803820

804821
fold_probs = res.probabilities[clf_idx][fold][:, target]
805-
return skl_metrics.roc_curve(
806-
fold_actual, fold_probs, pos_label=target
822+
drop_intermediate = len(fold_actual) > 20
823+
fpr, tpr, thresholds = skl_metrics.roc_curve(
824+
fold_actual, fold_probs, pos_label=target,
825+
drop_intermediate=drop_intermediate
807826
)
808827

828+
# skl sets the first threshold to the highest threshold in the data + 1
829+
# since we deal with probabilities, we (carefully) set it to 1
830+
# Unrelated comparisons, thus pylint: disable=chained-comparison
831+
if len(thresholds) > 1 and thresholds[1] <= 1:
832+
thresholds[0] = 1
833+
return fpr, tpr, thresholds
834+
809835

810836
def roc_curve_vertical_average(curves, samples=10):
811837
if not curves:
@@ -969,5 +995,19 @@ def roc_iso_performance_slope(fp_cost, fn_cost, p):
969995
return (fp_cost * (1. - p)) / (fn_cost * p)
970996

971997

998+
def _create_results(): # pragma: no cover
999+
probs1 = [0.984, 0.907, 0.881, 0.865, 0.815, 0.741, 0.735, 0.635,
1000+
0.582, 0.554, 0.413, 0.317, 0.287, 0.225, 0.216, 0.183]
1001+
probs = np.array([[[1 - x, x] for x in probs1]])
1002+
preds = (probs > 0.5).astype(float)
1003+
return Results(
1004+
data=Orange.data.Table("heart_disease")[:16],
1005+
row_indices=np.arange(16),
1006+
actual=np.array(list(map(int, "1100111001001000"))),
1007+
probabilities=probs, predicted=preds
1008+
)
1009+
1010+
9721011
if __name__ == "__main__": # pragma: no cover
1012+
# WidgetPreview(OWROCAnalysis).run(_create_results())
9731013
WidgetPreview(OWROCAnalysis).run(results_for_preview())

Orange/widgets/evaluate/tests/test_owrocanalysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def test_tooltips(self):
214214
pos = view.mapFromScene(pos)
215215
mouseMove(view.viewport(), pos)
216216
(_, text), _ = show_text.call_args
217-
self.assertIn("(#1) 1.800\n(#2) 1.893", text)
217+
self.assertIn("(#1) 1.000\n(#2) 1.000", text)
218218

219219
# test that cache is invalidated when changing averaging mode
220220
self.widget.roc_averaging = OWROCAnalysis.Threshold

0 commit comments

Comments
 (0)