Skip to content

Commit dd860da

Browse files
committed
Calibration plot: Output selected model
1 parent 3cfc122 commit dd860da

File tree

1 file changed

+45
-4
lines changed

1 file changed

+45
-4
lines changed

Orange/widgets/evaluate/owcalibrationplot.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77

88
import pyqtgraph as pg
99

10+
from Orange.classification import ModelWithThreshold
1011
from Orange.evaluation import Results
1112
from Orange.widgets import widget, gui, settings
1213
from Orange.widgets.evaluate.utils import \
1314
check_results_adequacy, results_for_preview
1415
from Orange.widgets.utils import colorpalette, colorbrewer
1516
from Orange.widgets.utils.widgetpreview import WidgetPreview
16-
from Orange.widgets.widget import Input
17+
from Orange.widgets.widget import Input, Output, Msg
1718
from Orange.widgets import report
1819

1920

@@ -98,16 +99,31 @@ class OWCalibrationPlot(widget.OWWidget):
9899
class Inputs:
99100
evaluation_results = Input("Evaluation Results", Results)
100101

102+
class Outputs:
103+
calibrated_model = Output("Calibrated Model", ModelWithThreshold)
104+
101105
class Warning(widget.OWWidget.Warning):
102106
empty_input = widget.Msg(
103107
"Empty result on input. Nothing to display.")
104108

109+
class Information(widget.OWWidget.Information):
110+
no_out = "Can't output a model: "
111+
no_output_multiple_folds = Msg(
112+
no_out + "every training data sample produced a different model")
113+
no_output_no_models = Msg(
114+
no_out + "test results do not contain stored models;\n"
115+
"try testing on separate data or on training data")
116+
no_output_multiple_selected = Msg(
117+
no_out + "select a single model - the widget can output only one")
118+
119+
105120
target_index = settings.Setting(0)
106121
selected_classifiers = settings.Setting([])
107122
score = settings.Setting(0)
108123
fold_curves = settings.Setting(False)
109124
display_rug = settings.Setting(True)
110125
threshold = settings.Setting(0.5)
126+
auto_commit = settings.Setting(True)
111127

112128
graph_name = "plot"
113129

@@ -135,7 +151,7 @@ def __init__(self):
135151
box="Classifier", selectionMode=QListWidget.ExtendedSelection,
136152
sizePolicy=(QSizePolicy.Preferred, QSizePolicy.Preferred),
137153
sizeHint=QSize(150, 40),
138-
callback=self._replot)
154+
callback=self._on_selection_changed)
139155

140156
box = gui.vBox(self.controlArea, "Metrics")
141157
combo = gui.comboBox(
@@ -152,6 +168,9 @@ def __init__(self):
152168
box = gui.widgetBox(self.controlArea, "Info")
153169
self.info_label = gui.widgetLabel(box)
154170

171+
gui.auto_commit(
172+
self.controlArea, self, "auto_commit", "Apply", commit=self.apply)
173+
155174
self.plotview = pg.GraphicsView(background="w")
156175
self.plot = pg.PlotItem(enableMenu=False)
157176
self.plot.setMouseEnabled(False, False)
@@ -181,6 +200,7 @@ def set_results(self, results):
181200
if self.results is not None:
182201
self._initialize(results)
183202
self._replot()
203+
self.apply()
184204

185205
def clear(self):
186206
self.plot.clear()
@@ -325,13 +345,16 @@ def _replot(self):
325345
def _on_display_rug_changed(self):
326346
self._replot()
327347

348+
def _on_selection_changed(self):
349+
self._replot()
350+
self.apply()
351+
328352
def threshold_change(self):
329353
self.threshold = round(self.line.pos().x(), 2)
330354
self.line.setPos(self.threshold)
331355
self._update_info()
332356

333357
def _update_info(self):
334-
335358
text = f"""<table>
336359
<tr>
337360
<th align='right'>Threshold: p=</th>
@@ -356,7 +379,25 @@ def _update_info(self):
356379
self.info_label.setText(text)
357380

358381
def threshold_change_done(self):
359-
...
382+
self.apply()
383+
384+
def apply(self):
385+
info = self.Information
386+
wrapped = None
387+
if self.results is not None:
388+
problems = {
389+
info.no_output_multiple_folds: len(self.results.folds) > 1,
390+
info.no_output_no_models: self.results.models is None,
391+
info.no_output_multiple_selected:
392+
len(self.selected_classifiers) != 1}
393+
if not any(problems.values()):
394+
model = self.results.models[0][self.selected_classifiers[0]]
395+
wrapped = ModelWithThreshold(model, self.threshold)
396+
397+
self.Outputs.calibrated_model.send(wrapped)
398+
for info, shown in problems.items():
399+
if info.is_shown() != shown:
400+
info(shown=shown)
360401

361402
def send_report(self):
362403
if self.results is None:

0 commit comments

Comments
 (0)