Skip to content

Commit 8f1daf0

Browse files
authored
Merge pull request #1595 from ales-erjavec/fix-roc-averaging-undefined
[FIX] ROC Analysis - Fix roc averaging
2 parents 9ef2e65 + e6dea04 commit 8f1daf0

File tree

2 files changed

+99
-22
lines changed

2 files changed

+99
-22
lines changed

Orange/widgets/evaluate/owrocanalysis.py

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ def ROCData_from_results(results, clf_index, target):
9393
folds = results.folds if results.folds is not None else [slice(0, -1)]
9494
fold_curves = []
9595
for fold in folds:
96-
# TODO: Check for no FP or no TP
9796
points = roc_curve_for_fold(results, fold, clf_index, target)
9897
hull = roc_curve_convex_hull(points)
9998
c = ROCCurve(ROCPoints(*points), ROCPoints(*hull))
@@ -102,32 +101,49 @@ def ROCData_from_results(results, clf_index, target):
102101
curves = [fold.points for fold in fold_curves
103102
if fold.is_valid]
104103

105-
fpr, tpr, std = roc_curve_vertical_average(curves)
106-
thresh = numpy.zeros_like(fpr) * numpy.nan
107-
hull = roc_curve_convex_hull((fpr, tpr, thresh))
108-
v_avg = ROCAveragedVert(
109-
ROCPoints(fpr, tpr, thresh),
110-
ROCPoints(*hull),
111-
std
112-
)
104+
if curves:
105+
fpr, tpr, std = roc_curve_vertical_average(curves)
113106

114-
all_thresh = numpy.hstack([t for _, _, t in curves])
115-
all_thresh = numpy.clip(all_thresh, 0.0 - 1e-10, 1.0 + 1e-10)
116-
all_thresh = numpy.unique(all_thresh)[::-1]
117-
thresh = all_thresh[::max(all_thresh.size // 10, 1)]
107+
thresh = numpy.zeros_like(fpr) * numpy.nan
108+
hull = roc_curve_convex_hull((fpr, tpr, thresh))
109+
v_avg = ROCAveragedVert(
110+
ROCPoints(fpr, tpr, thresh),
111+
ROCPoints(*hull),
112+
std
113+
)
114+
else:
115+
# return an invalid vertical averaged ROC
116+
v_avg = ROCAveragedVert(
117+
ROCPoints(numpy.array([]), numpy.array([]), numpy.array([])),
118+
ROCPoints(numpy.array([]), numpy.array([]), numpy.array([])),
119+
numpy.array([])
120+
)
118121

119-
(fpr, fpr_std), (tpr, tpr_std) = \
120-
roc_curve_threshold_average(curves, thresh)
122+
if curves:
123+
all_thresh = numpy.hstack([t for _, _, t in curves])
124+
all_thresh = numpy.clip(all_thresh, 0.0 - 1e-10, 1.0 + 1e-10)
125+
all_thresh = numpy.unique(all_thresh)[::-1]
126+
thresh = all_thresh[::max(all_thresh.size // 10, 1)]
121127

122-
hull = roc_curve_convex_hull((fpr, tpr, thresh))
128+
(fpr, fpr_std), (tpr, tpr_std) = \
129+
roc_curve_threshold_average(curves, thresh)
123130

124-
t_avg = ROCAveragedThresh(
125-
ROCPoints(fpr, tpr, thresh),
126-
ROCPoints(*hull),
127-
tpr_std,
128-
fpr_std
129-
)
131+
hull = roc_curve_convex_hull((fpr, tpr, thresh))
130132

133+
t_avg = ROCAveragedThresh(
134+
ROCPoints(fpr, tpr, thresh),
135+
ROCPoints(*hull),
136+
tpr_std,
137+
fpr_std
138+
)
139+
else:
140+
# return an invalid threshold averaged ROC
141+
t_avg = ROCAveragedThresh(
142+
ROCPoints(numpy.array([]), numpy.array([]), numpy.array([])),
143+
ROCPoints(numpy.array([]), numpy.array([]), numpy.array([])),
144+
numpy.array([]),
145+
numpy.array([])
146+
)
131147
return ROCData(merged_curve, fold_curves, v_avg, t_avg)
132148

133149
ROCData.from_results = staticmethod(ROCData_from_results)
@@ -670,6 +686,8 @@ def roc_curve_for_fold(res, fold, clf_idx, target):
670686

671687

672688
def roc_curve_vertical_average(curves, samples=10):
689+
if not len(curves):
690+
raise ValueError("No curves")
673691
fpr_sample = numpy.linspace(0.0, 1.0, samples)
674692
tpr_samples = []
675693
for fpr, tpr, _ in curves:
@@ -680,6 +698,8 @@ def roc_curve_vertical_average(curves, samples=10):
680698

681699

682700
def roc_curve_threshold_average(curves, thresh_samples):
701+
if not len(curves):
702+
raise ValueError("No curves")
683703
fpr_samples, tpr_samples = [], []
684704
for fpr, tpr, thresh in curves:
685705
ind = numpy.searchsorted(thresh[::-1], thresh_samples, side="left")
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import unittest
2+
import numpy
3+
4+
import Orange.data
5+
import Orange.evaluation
6+
import Orange.classification
7+
8+
from Orange.widgets.evaluate import owrocanalysis
9+
10+
11+
class TestROC(unittest.TestCase):
12+
def test_ROCData_from_results(self):
13+
data = Orange.data.Table("iris")
14+
learners = [
15+
Orange.classification.MajorityLearner(),
16+
Orange.classification.LogisticRegressionLearner(),
17+
Orange.classification.TreeLearner()
18+
]
19+
res = Orange.evaluation.CrossValidation(data, learners, k=10)
20+
21+
for i, _ in enumerate(learners):
22+
for c in range(len(data.domain.class_var.values)):
23+
rocdata = owrocanalysis.ROCData_from_results(res, i, target=c)
24+
self.assertTrue(rocdata.merged.is_valid)
25+
self.assertEqual(len(rocdata.folds), 10)
26+
self.assertTrue(all(c.is_valid for c in rocdata.folds))
27+
self.assertTrue(rocdata.avg_vertical.is_valid)
28+
self.assertTrue(rocdata.avg_threshold.is_valid)
29+
30+
data = data[numpy.random.choice(len(data), size=20)]
31+
res = Orange.evaluation.LeaveOneOut(data, learners)
32+
33+
for i, _ in enumerate(learners):
34+
for c in range(len(data.domain.class_var.values)):
35+
rocdata = owrocanalysis.ROCData_from_results(res, i, target=c)
36+
self.assertTrue(rocdata.merged.is_valid)
37+
self.assertEqual(len(rocdata.folds), 20)
38+
# all individual fold curves and averaged curve data
39+
# should be invalid
40+
self.assertTrue(all(not c.is_valid for c in rocdata.folds))
41+
self.assertFalse(rocdata.avg_vertical.is_valid)
42+
self.assertFalse(rocdata.avg_threshold.is_valid)
43+
44+
# equivalent test to the LeaveOneOut but from a slightly different
45+
# constructed Orange.evaluation.Results
46+
res = Orange.evaluation.CrossValidation(data, learners, k=20)
47+
48+
for i, _ in enumerate(learners):
49+
for c in range(len(data.domain.class_var.values)):
50+
rocdata = owrocanalysis.ROCData_from_results(res, i, target=c)
51+
self.assertTrue(rocdata.merged.is_valid)
52+
self.assertEqual(len(rocdata.folds), 20)
53+
# all individual fold curves and averaged curve data
54+
# should be invalid
55+
self.assertTrue(all(not c.is_valid for c in rocdata.folds))
56+
self.assertFalse(rocdata.avg_vertical.is_valid)
57+
self.assertFalse(rocdata.avg_threshold.is_valid)

0 commit comments

Comments
 (0)