Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 42 additions & 22 deletions Orange/widgets/evaluate/owrocanalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def ROCData_from_results(results, clf_index, target):
folds = results.folds if results.folds is not None else [slice(0, -1)]
fold_curves = []
for fold in folds:
# TODO: Check for no FP or no TP
points = roc_curve_for_fold(results, fold, clf_index, target)
hull = roc_curve_convex_hull(points)
c = ROCCurve(ROCPoints(*points), ROCPoints(*hull))
Expand All @@ -102,32 +101,49 @@ def ROCData_from_results(results, clf_index, target):
curves = [fold.points for fold in fold_curves
if fold.is_valid]

fpr, tpr, std = roc_curve_vertical_average(curves)
thresh = numpy.zeros_like(fpr) * numpy.nan
hull = roc_curve_convex_hull((fpr, tpr, thresh))
v_avg = ROCAveragedVert(
ROCPoints(fpr, tpr, thresh),
ROCPoints(*hull),
std
)
if curves:
fpr, tpr, std = roc_curve_vertical_average(curves)

all_thresh = numpy.hstack([t for _, _, t in curves])
all_thresh = numpy.clip(all_thresh, 0.0 - 1e-10, 1.0 + 1e-10)
all_thresh = numpy.unique(all_thresh)[::-1]
thresh = all_thresh[::max(all_thresh.size // 10, 1)]
thresh = numpy.zeros_like(fpr) * numpy.nan
hull = roc_curve_convex_hull((fpr, tpr, thresh))
v_avg = ROCAveragedVert(
ROCPoints(fpr, tpr, thresh),
ROCPoints(*hull),
std
)
else:
# return an invalid vertical averaged ROC
v_avg = ROCAveragedVert(
ROCPoints(numpy.array([]), numpy.array([]), numpy.array([])),
ROCPoints(numpy.array([]), numpy.array([]), numpy.array([])),
numpy.array([])
)

(fpr, fpr_std), (tpr, tpr_std) = \
roc_curve_threshold_average(curves, thresh)
if curves:
all_thresh = numpy.hstack([t for _, _, t in curves])
all_thresh = numpy.clip(all_thresh, 0.0 - 1e-10, 1.0 + 1e-10)
all_thresh = numpy.unique(all_thresh)[::-1]
thresh = all_thresh[::max(all_thresh.size // 10, 1)]

hull = roc_curve_convex_hull((fpr, tpr, thresh))
(fpr, fpr_std), (tpr, tpr_std) = \
roc_curve_threshold_average(curves, thresh)

t_avg = ROCAveragedThresh(
ROCPoints(fpr, tpr, thresh),
ROCPoints(*hull),
tpr_std,
fpr_std
)
hull = roc_curve_convex_hull((fpr, tpr, thresh))

t_avg = ROCAveragedThresh(
ROCPoints(fpr, tpr, thresh),
ROCPoints(*hull),
tpr_std,
fpr_std
)
else:
# return an invalid threshold averaged ROC
t_avg = ROCAveragedThresh(
ROCPoints(numpy.array([]), numpy.array([]), numpy.array([])),
ROCPoints(numpy.array([]), numpy.array([]), numpy.array([])),
numpy.array([]),
numpy.array([])
)
return ROCData(merged_curve, fold_curves, v_avg, t_avg)

ROCData.from_results = staticmethod(ROCData_from_results)
Expand Down Expand Up @@ -670,6 +686,8 @@ def roc_curve_for_fold(res, fold, clf_idx, target):


def roc_curve_vertical_average(curves, samples=10):
if not len(curves):
raise ValueError("No curves")
fpr_sample = numpy.linspace(0.0, 1.0, samples)
tpr_samples = []
for fpr, tpr, _ in curves:
Expand All @@ -680,6 +698,8 @@ def roc_curve_vertical_average(curves, samples=10):


def roc_curve_threshold_average(curves, thresh_samples):
if not len(curves):
raise ValueError("No curves")
fpr_samples, tpr_samples = [], []
for fpr, tpr, thresh in curves:
ind = numpy.searchsorted(thresh[::-1], thresh_samples, side="left")
Expand Down
57 changes: 57 additions & 0 deletions Orange/widgets/evaluate/tests/test_owrocanalysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import unittest
import numpy

import Orange.data
import Orange.evaluation
import Orange.classification

from Orange.widgets.evaluate import owrocanalysis


class TestROC(unittest.TestCase):
    """Tests for ``owrocanalysis.ROCData_from_results``.

    Verifies that the merged, per-fold and averaged ROC curves are
    flagged valid for ordinary cross-validation, and invalid for
    degenerate single-instance folds (leave-one-out style results).
    """

    def _check_all_valid(self, res, data, learners, n_folds):
        # Every learner/target combination should yield fully valid
        # merged, per-fold and averaged ROC curve data.
        for clf_idx, _ in enumerate(learners):
            for target in range(len(data.domain.class_var.values)):
                rocdata = owrocanalysis.ROCData_from_results(
                    res, clf_idx, target=target)
                self.assertTrue(rocdata.merged.is_valid)
                self.assertEqual(len(rocdata.folds), n_folds)
                self.assertTrue(all(fold.is_valid for fold in rocdata.folds))
                self.assertTrue(rocdata.avg_vertical.is_valid)
                self.assertTrue(rocdata.avg_threshold.is_valid)

    def _check_folds_invalid(self, res, data, learners, n_folds):
        # A fold with a single test instance cannot define a ROC curve,
        # so every individual fold curve and both averaged curves must be
        # invalid; only the merged curve stays valid.
        for clf_idx, _ in enumerate(learners):
            for target in range(len(data.domain.class_var.values)):
                rocdata = owrocanalysis.ROCData_from_results(
                    res, clf_idx, target=target)
                self.assertTrue(rocdata.merged.is_valid)
                self.assertEqual(len(rocdata.folds), n_folds)
                self.assertTrue(
                    all(not fold.is_valid for fold in rocdata.folds))
                self.assertFalse(rocdata.avg_vertical.is_valid)
                self.assertFalse(rocdata.avg_threshold.is_valid)

    def test_ROCData_from_results(self):
        data = Orange.data.Table("iris")
        learners = [
            Orange.classification.MajorityLearner(),
            Orange.classification.LogisticRegressionLearner(),
            Orange.classification.TreeLearner()
        ]
        res = Orange.evaluation.CrossValidation(data, learners, k=10)
        self._check_all_valid(res, data, learners, n_folds=10)

        # Subsample with a fixed RandomState so the test is deterministic
        # instead of depending on (and mutating) the global numpy RNG.
        data = data[numpy.random.RandomState(0).choice(len(data), size=20)]
        res = Orange.evaluation.LeaveOneOut(data, learners)
        self._check_folds_invalid(res, data, learners, n_folds=20)

        # Equivalent to the LeaveOneOut case above, but built from a
        # slightly differently constructed Orange.evaluation.Results
        # (k-fold CV with k == len(data)).
        res = Orange.evaluation.CrossValidation(data, learners, k=20)
        self._check_folds_invalid(res, data, learners, n_folds=20)