Skip to content

Commit 2312bf4

Browse files
committed
Calibration plot: Output selected model
1 parent 244a56e commit 2312bf4

File tree

1 file changed

+45
-4
lines changed

1 file changed

+45
-4
lines changed

Orange/widgets/evaluate/owcalibrationplot.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77

88
import pyqtgraph as pg
99

10+
from Orange.classification import ModelWithThreshold
1011
from Orange.evaluation import Results
1112
from Orange.widgets import widget, gui, settings
1213
from Orange.widgets.evaluate.utils import \
1314
check_results_adequacy, results_for_preview
1415
from Orange.widgets.utils import colorpalette, colorbrewer
1516
from Orange.widgets.utils.widgetpreview import WidgetPreview
16-
from Orange.widgets.widget import Input
17+
from Orange.widgets.widget import Input, Output, Msg
1718
from Orange.widgets import report
1819

1920

@@ -98,16 +99,31 @@ class OWCalibrationPlot(widget.OWWidget):
9899
class Inputs:
99100
evaluation_results = Input("Evaluation Results", Results)
100101

102+
class Outputs:
103+
calibrated_model = Output("Calibrated Model", ModelWithThreshold)
104+
101105
class Warning(widget.OWWidget.Warning):
102106
empty_input = widget.Msg(
103107
"Empty result on input. Nothing to display.")
104108

109+
class Information(widget.OWWidget.Information):
110+
no_out = "Can't output a model: "
111+
no_output_multiple_folds = Msg(
112+
no_out + "every training data sample produced a different model")
113+
no_output_no_models = Msg(
114+
no_out + "test results do not contain stored models;\n"
115+
"try testing on separate data or on training data")
116+
no_output_multiple_selected = Msg(
117+
no_out + "select a single model - the widget can output only one")
118+
119+
105120
target_index = settings.Setting(0)
106121
selected_classifiers = settings.Setting([])
107122
score = settings.Setting(0)
108123
fold_curves = settings.Setting(False)
109124
display_rug = settings.Setting(True)
110125
threshold = settings.Setting(0.5)
126+
auto_commit = settings.Setting(True)
111127

112128
graph_name = "plot"
113129

@@ -136,7 +152,7 @@ def __init__(self):
136152
box="Classifier", selectionMode=QListWidget.ExtendedSelection,
137153
sizePolicy=(QSizePolicy.Preferred, QSizePolicy.Preferred),
138154
sizeHint=QSize(150, 40),
139-
callback=self._replot)
155+
callback=self._on_selection_changed)
140156

141157
box = gui.vBox(self.controlArea, "Metrics")
142158
combo = gui.comboBox(
@@ -153,6 +169,9 @@ def __init__(self):
153169
box = gui.widgetBox(self.controlArea, "Info")
154170
self.info_label = gui.widgetLabel(box)
155171

172+
gui.auto_commit(
173+
self.controlArea, self, "auto_commit", "Apply", commit=self.apply)
174+
156175
self.plotview = pg.GraphicsView(background="w")
157176
self.plot = pg.PlotItem(enableMenu=False)
158177
self.plot.setMouseEnabled(False, False)
@@ -182,6 +201,7 @@ def set_results(self, results):
182201
if self.results is not None:
183202
self._initialize(results)
184203
self._replot()
204+
self.apply()
185205

186206
def clear(self):
187207
self.plot.clear()
@@ -326,13 +346,16 @@ def _replot(self):
326346
def _on_display_rug_changed(self):
327347
self._replot()
328348

349+
def _on_selection_changed(self):
350+
self._replot()
351+
self.apply()
352+
329353
def threshold_change(self):
330354
self.threshold = round(self.line.pos().x(), 2)
331355
self.line.setPos(self.threshold)
332356
self._update_info()
333357

334358
def _update_info(self):
335-
336359
text = f"""<table>
337360
<tr>
338361
<th align='right'>Threshold: p=</th>
@@ -357,7 +380,25 @@ def _update_info(self):
357380
self.info_label.setText(text)
358381

359382
def threshold_change_done(self):
360-
...
383+
self.apply()
384+
385+
def apply(self):
386+
info = self.Information
387+
wrapped = None
388+
if self.results is not None:
389+
problems = {
390+
info.no_output_multiple_folds: len(self.results.folds) > 1,
391+
info.no_output_no_models: self.results.models is None,
392+
info.no_output_multiple_selected:
393+
len(self.selected_classifiers) != 1}
394+
if not any(problems.values()):
395+
model = self.results.models[0][self.selected_classifiers[0]]
396+
wrapped = ModelWithThreshold(model, self.threshold)
397+
398+
self.Outputs.calibrated_model.send(wrapped)
399+
for info, shown in problems.items():
400+
if info.is_shown() != shown:
401+
info(shown=shown)
361402

362403
def send_report(self):
363404
if self.results is None:

0 commit comments

Comments
 (0)