Skip to content

Commit 5a5ebcf

Browse files
authored
Merge pull request #6618 from janezd/roc-output
ROC: Output a model with a new operating threshold
2 parents 46a9b28 + 410d76a commit 5a5ebcf

File tree

9 files changed

+243
-102
lines changed

9 files changed

+243
-102
lines changed

Orange/widgets/evaluate/owcalibrationplot.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from Orange.widgets import widget, gui, settings
1717
from Orange.widgets.evaluate.contexthandlers import \
1818
EvaluationResultsContextHandler
19-
from Orange.widgets.evaluate.utils import results_for_preview
19+
from Orange.widgets.evaluate.utils import results_for_preview, \
20+
check_can_calibrate
2021
from Orange.widgets.utils import colorpalettes
2122
from Orange.widgets.utils.widgetpreview import WidgetPreview
2223
from Orange.widgets.visualize.utils.customizableplot import \
@@ -486,23 +487,11 @@ def commit(self):
486487
wrapped = None
487488
results = self.results
488489
if results is not None:
489-
problems = [
490-
msg for condition, msg in (
491-
(results.folds is not None and len(results.folds) > 1,
492-
"each training data sample produces a different model"),
493-
(results.models is None,
494-
"test results do not contain stored models - try testing "
495-
"on separate data or on training data"),
496-
(len(self.selected_classifiers) != 1,
497-
"select a single model - the widget can output only one"),
498-
(self.score != 0 and len(results.domain.class_var.values) != 2,
499-
"cannot calibrate non-binary classes"))
500-
if condition]
501-
if len(problems) == 1:
502-
self.Information.no_output(problems[0])
503-
elif problems:
504-
self.Information.no_output(
505-
"".join(f"\n - {problem}" for problem in problems))
490+
problems = check_can_calibrate(
491+
self.results, self.selected_classifiers,
492+
require_binary=self.score != 0)
493+
if problems:
494+
self.Information.no_output(problems)
506495
else:
507496
clsf_idx = self.selected_classifiers[0]
508497
model = results.models[0, clsf_idx]

Orange/widgets/evaluate/owliftcurve.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from Orange.widgets import widget, gui, settings
2121
from Orange.widgets.evaluate.contexthandlers import \
2222
EvaluationResultsContextHandler
23-
from Orange.widgets.evaluate.utils import check_results_adequacy
23+
from Orange.widgets.evaluate.utils import check_results_adequacy, \
24+
check_can_calibrate
2425
from Orange.widgets.utils import colorpalettes
2526
from Orange.widgets.utils.widgetpreview import WidgetPreview
2627
from Orange.widgets.visualize.utils.customizableplot import Updater, \
@@ -267,7 +268,7 @@ def _initialize(self, results):
267268
item = self.classifiers_list_box.item(i)
268269
item.setIcon(colorpalettes.ColorIcon(color))
269270

270-
class_values = results.data.domain.class_var.values
271+
class_values = results.domain.class_var.values
271272
self.target_cb.addItems(class_values)
272273
if class_values:
273274
self.target_index = 0
@@ -493,23 +494,10 @@ def commit(self):
493494
wrapped = None
494495
results = self.results
495496
if results is not None:
496-
problems = [
497-
msg for condition, msg in (
498-
(results.folds is not None and len(results.folds) > 1,
499-
"each training data sample produces a different model"),
500-
(results.models is None,
501-
"test results do not contain stored models - try testing "
502-
"on separate data or on training data"),
503-
(len(self.selected_classifiers) != 1,
504-
"select a single model - the widget can output only one"),
505-
(len(results.domain.class_var.values) != 2,
506-
"cannot calibrate non-binary classes"))
507-
if condition]
508-
if len(problems) == 1:
509-
self.Information.no_output(problems[0])
510-
elif problems:
511-
self.Information.no_output(
512-
"".join(f"\n - {problem}" for problem in problems))
497+
problems = check_can_calibrate(
498+
self.results, self.selected_classifiers)
499+
if problems:
500+
self.Information.no_output(problems)
513501
else:
514502
clsf_idx = self.selected_classifiers[0]
515503
model = results.models[0, clsf_idx]

Orange/widgets/evaluate/owrocanalysis.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,21 @@
1313
import pyqtgraph as pg
1414

1515
import Orange
16+
from Orange.base import Model
17+
from Orange.classification import ThresholdClassifier
18+
from Orange.evaluation.testing import Results
1619
from Orange.widgets import widget, gui, settings
1720
from Orange.widgets.evaluate.contexthandlers import \
1821
EvaluationResultsContextHandler
19-
from Orange.widgets.evaluate.utils import check_results_adequacy
22+
from Orange.widgets.evaluate.utils import check_results_adequacy, \
23+
check_can_calibrate
2024
from Orange.widgets.utils import colorpalettes
2125
from Orange.widgets.utils.widgetpreview import WidgetPreview
2226
from Orange.widgets.visualize.utils.plotutils import GraphicsView, PlotItem
23-
from Orange.widgets.widget import Input
27+
from Orange.widgets.widget import Input, Output, Msg
2428
from Orange.widgets import report
2529

2630
from Orange.widgets.evaluate.utils import results_for_preview
27-
from Orange.evaluation.testing import Results
2831

2932

3033
#: Points on a ROC curve
@@ -305,6 +308,12 @@ class OWROCAnalysis(widget.OWWidget):
305308
class Inputs:
306309
evaluation_results = Input("Evaluation Results", Orange.evaluation.Results)
307310

311+
class Outputs:
312+
calibrated_model = Output("Calibrated Model", Model)
313+
314+
class Information(widget.OWWidget.Information):
315+
no_output = Msg("Can't output a model: {}")
316+
308317
buttons_area_orientation = None
309318
settingsHandler = EvaluationResultsContextHandler()
310319
target_index = settings.ContextSetting(0)
@@ -466,7 +475,7 @@ def _initialize(self, results):
466475
listitem = self.classifiers_list_box.item(i)
467476
listitem.setIcon(colorpalettes.ColorIcon(self.colors[i]))
468477

469-
class_var = results.data.domain.class_var
478+
class_var = results.domain.class_var
470479
self.target_cb.addItems(class_var.values)
471480
self.target_index = 0
472481
self._set_target_prior()
@@ -620,8 +629,7 @@ def no_averaging():
620629
pen.setCosmetic(True)
621630
self.plot.plot([0, 1], [0, 1], pen=pen, antialias=True)
622631

623-
if self.roc_averaging == OWROCAnalysis.Merge:
624-
self._update_perf_line()
632+
self._update_perf_line()
625633

626634
self._update_axes_ticks()
627635

@@ -730,8 +738,7 @@ def _on_target_prior_changed(self):
730738
self._on_display_perf_line_changed()
731739

732740
def _on_display_perf_line_changed(self):
733-
if self.roc_averaging == OWROCAnalysis.Merge:
734-
self._update_perf_line()
741+
self._update_perf_line()
735742

736743
if self.perf_line is not None:
737744
self.perf_line.setVisible(self.display_perf_line)
@@ -745,9 +752,12 @@ def _replot(self):
745752
self._setup_plot()
746753

747754
def _update_perf_line(self):
748-
if self._perf_line is None:
755+
756+
if self._perf_line is None or self.roc_averaging != OWROCAnalysis.Merge:
757+
self._update_output(None)
749758
return
750759

760+
ind = None
751761
self._perf_line.setVisible(self.display_perf_line)
752762
if self.display_perf_line:
753763
m = roc_iso_performance_slope(
@@ -762,6 +772,26 @@ def _update_perf_line(self):
762772
else:
763773
self._perf_line.setVisible(False)
764774

775+
self._update_output(None if ind is None else hull.thresholds[ind[0]])
776+
777+
def _update_output(self, threshold):
778+
self.Information.no_output.clear()
779+
780+
if threshold is None:
781+
self.Outputs.calibrated_model.send(None)
782+
return
783+
784+
problems = check_can_calibrate(self.results, self.selected_classifiers)
785+
if problems:
786+
self.Information.no_output(problems)
787+
self.Outputs.calibrated_model.send(None)
788+
return
789+
790+
model = ThresholdClassifier(
791+
self.results.models[0][self.selected_classifiers[0]],
792+
threshold)
793+
self.Outputs.calibrated_model.send(model)
794+
765795
def onDeleteWidget(self):
766796
self.clear()
767797

Orange/widgets/evaluate/tests/base.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,59 @@
1+
from unittest.mock import Mock
2+
3+
import numpy as np
4+
15
from Orange import classification, evaluation
2-
from Orange.data import Table
6+
from Orange.data import Table, Domain, DiscreteVariable
7+
from Orange.evaluation import Results
8+
from Orange.evaluation.performance_curves import Curves
9+
from Orange.tests import test_filename
10+
11+
from Orange.widgets.tests.base import WidgetTest
12+
313

14+
class EvaluateTest(WidgetTest):
15+
def setUp(self):
16+
super().setUp()
17+
18+
n, p = (0, 1)
19+
actual, probs = np.array([
20+
(p, .8), (n, .7), (p, .6), (p, .55), (p, .54), (n, .53), (n, .52),
21+
(p, .51), (n, .505), (p, .4), (n, .39), (p, .38), (n, .37),
22+
(n, .36), (n, .35), (p, .34), (n, .33), (p, .30), (n, .1)]).T
23+
self.curves = Curves(actual, probs)
24+
probs2 = (probs + 1) / 2
25+
self.curves2 = Curves(actual, probs2)
26+
pred = probs > 0.5
27+
pred2 = probs2 > 0.5
28+
probs = np.vstack((1 - probs, probs)).T
29+
probs2 = np.vstack((1 - probs2, probs2)).T
30+
domain = Domain([], DiscreteVariable("y", values=("a", "b")))
31+
self.results = Results(
32+
domain=domain,
33+
actual=actual,
34+
folds=np.array([Ellipsis]),
35+
models=np.array([[Mock(), Mock()]]),
36+
row_indices=np.arange(19),
37+
predicted=np.array((pred, pred2)),
38+
probabilities=np.array([probs, probs2]))
39+
40+
self.lenses = data = Table(test_filename("datasets/lenses.tab"))
41+
majority = classification.MajorityLearner()
42+
majority.name = "majority"
43+
knn3 = classification.KNNLearner(n_neighbors=3)
44+
knn3.name = "knn-3"
45+
knn1 = classification.KNNLearner(n_neighbors=1)
46+
knn1.name = "knn-1"
47+
self.lenses_results = evaluation.TestOnTestData(
48+
store_data=True, store_models=True)(
49+
data=data[::2], test_data=data[1::2],
50+
learners=[majority, knn3, knn1])
51+
self.lenses_results.learner_names = ["majority", "knn-3", "knn-1"]
452

5-
class EvaluateTest:
653
def test_many_evaluation_results(self):
54+
if not hasattr(self, "widget"):
55+
return
56+
757
data = Table("iris")
858
learners = [
959
classification.MajorityLearner(),

Orange/widgets/evaluate/tests/test_owcalibrationplot.py

Lines changed: 5 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,56 +11,15 @@
1111

1212
from orangewidget.utils.combobox import qcombobox_emit_activated
1313

14-
from Orange.data import Table, DiscreteVariable, Domain, ContinuousVariable
15-
import Orange.evaluation
16-
import Orange.classification
17-
from Orange.evaluation import Results
14+
from Orange.data import Domain, ContinuousVariable
1815
from Orange.evaluation.performance_curves import Curves
1916
from Orange.widgets.evaluate.tests.base import EvaluateTest
2017
from Orange.widgets.evaluate.owcalibrationplot import OWCalibrationPlot
21-
from Orange.widgets.tests.base import WidgetTest
22-
from Orange.tests import test_filename
2318

2419

25-
class TestOWCalibrationPlot(WidgetTest, EvaluateTest):
20+
class TestOWCalibrationPlot(EvaluateTest):
2621
def setUp(self):
2722
super().setUp()
28-
29-
n, p = (0, 1)
30-
actual, probs = np.array([
31-
(p, .8), (n, .7), (p, .6), (p, .55), (p, .54), (n, .53), (n, .52),
32-
(p, .51), (n, .505), (p, .4), (n, .39), (p, .38), (n, .37),
33-
(n, .36), (n, .35), (p, .34), (n, .33), (p, .30), (n, .1)]).T
34-
self.curves = Curves(actual, probs)
35-
probs2 = (probs + 0.5) / 2 + 1
36-
self.curves2 = Curves(actual, probs2)
37-
pred = probs > 0.5
38-
pred2 = probs2 > 0.5
39-
probs = np.vstack((1 - probs, probs)).T
40-
probs2 = np.vstack((1 - probs2, probs2)).T
41-
domain = Domain([], DiscreteVariable("y", values=("a", "b")))
42-
self.results = Results(
43-
domain=domain,
44-
actual=actual,
45-
folds=np.array([Ellipsis]),
46-
models=np.array([[Mock(), Mock()]]),
47-
row_indices=np.arange(19),
48-
predicted=np.array((pred, pred2)),
49-
probabilities=np.array([probs, probs2]))
50-
51-
self.lenses = data = Table(test_filename("datasets/lenses.tab"))
52-
majority = Orange.classification.MajorityLearner()
53-
majority.name = "majority"
54-
knn3 = Orange.classification.KNNLearner(n_neighbors=3)
55-
knn3.name = "knn-3"
56-
knn1 = Orange.classification.KNNLearner(n_neighbors=1)
57-
knn1.name = "knn-1"
58-
self.lenses_results = Orange.evaluation.TestOnTestData(
59-
store_data=True, store_models=True)(
60-
data=data[::2], test_data=data[1::2],
61-
learners=[majority, knn3, knn1])
62-
self.lenses_results.learner_names = ["majority", "knn-3", "knn-1"]
63-
6423
self.widget = self.create_widget(OWCalibrationPlot) # type: OWCalibrationPlot
6524
warnings.filterwarnings("ignore", ".*", ConvergenceWarning)
6625

@@ -382,6 +341,8 @@ def test_threshold_flips_on_two_classes(self):
382341
@patch("Orange.widgets.evaluate.owcalibrationplot.CalibratedLearner")
383342
def test_apply_no_output(self, *_):
384343
"""Test no output warnings"""
344+
# Similar to test_owcalibrationplot, but just a little different, hence
345+
# pylint: disable=duplicate-code
385346
widget = self.widget
386347
model_list = widget.controls.selected_classifiers
387348

@@ -395,7 +356,7 @@ def test_apply_no_output(self, *_):
395356
multiple_selected:
396357
"select a single model - the widget can output only one",
397358
non_binary_class:
398-
"cannot calibrate non-binary classes"}
359+
"cannot calibrate non-binary models"}
399360

400361
def test_shown(shown):
401362
widget_msg = widget.Information.no_output

Orange/widgets/evaluate/tests/test_owliftcurve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
SKIP_REASON = "Only test precision-recall with scikit-learn>=1.1.1"
3232

3333

34-
class TestOWLiftCurve(WidgetTest, EvaluateTest):
34+
class TestOWLiftCurve(EvaluateTest):
3535
@classmethod
3636
def setUpClass(cls):
3737
super().setUpClass()

0 commit comments

Comments (0)