Skip to content

Commit 2edcb39

Browse files
committed
OWCalibration Plot: Unit tests and some fixes
1 parent 6ac1db1 commit 2edcb39

File tree

5 files changed

+575
-38
lines changed

5 files changed

+575
-38
lines changed

Orange/evaluation/testing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def set_or_raise(value, exp_values, msg):
171171
"mismatching number of class values")
172172
nmethods = set_or_raise(
173173
nmethods, [learners is not None and len(learners),
174-
models is not None and len(models),
174+
models is not None and models.shape[1],
175175
failed is not None and len(failed),
176176
predicted is not None and predicted.shape[0],
177177
probabilities is not None and probabilities.shape[0]],
@@ -365,7 +365,7 @@ def __new__(cls,
365365
"and train_data are omitted")
366366
return self
367367

368-
warn("calling Validation's constructor with data and learners"
368+
warn("calling Validation's constructor with data and learners "
369369
"is deprecated;\nconstruct an instance and call it",
370370
DeprecationWarning, stacklevel=2)
371371

Orange/tests/test_evaluation_testing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,7 @@ def setUp(self):
756756
self.row_indices = np.arange(100)
757757
self.folds = (range(50), range(10, 60)), (range(50, 100), range(50))
758758
self.learners = [MajorityLearner(), MajorityLearner()]
759-
self.models = [Mock(), Mock()]
759+
self.models = np.array([[Mock(), Mock()]])
760760
self.predicted = np.zeros((2, 100))
761761
self.probabilities = np.zeros((2, 100, 3))
762762
self.failed = [False, True]

Orange/widgets/evaluate/owcalibrationplot.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
from Orange.widgets import widget, gui, settings
1515
from Orange.widgets.evaluate.contexthandlers import \
1616
EvaluationResultsContextHandler
17-
from Orange.widgets.evaluate.utils import \
18-
check_results_adequacy, results_for_preview
17+
from Orange.widgets.evaluate.utils import results_for_preview
1918
from Orange.widgets.utils import colorpalette, colorbrewer
2019
from Orange.widgets.utils.widgetpreview import WidgetPreview
2120
from Orange.widgets.widget import Input, Output, Msg
@@ -72,7 +71,8 @@ class Inputs:
7271
class Outputs:
7372
calibrated_model = Output("Calibrated Model", Model)
7473

75-
class Warning(widget.OWWidget.Warning):
74+
class Error(widget.OWWidget.Error):
75+
non_discrete_target = Msg("Calibration plot requires a discrete target")
7676
empty_input = widget.Msg("Empty result on input. Nothing to display.")
7777

7878
class Information(widget.OWWidget.Information):
@@ -84,7 +84,8 @@ class Information(widget.OWWidget.Information):
8484
"try testing on separate data or on training data")
8585
no_output_multiple_selected = Msg(
8686
no_out + "select a single model - the widget can output only one")
87-
non_binary_class = Msg(no_out + "cannot calibrate non-binary classes")
87+
no_output_non_binary_class = Msg(
88+
no_out + "cannot calibrate non-binary classes")
8889

8990
settingsHandler = EvaluationResultsContextHandler()
9091
target_index = settings.ContextSetting(0)
@@ -145,8 +146,8 @@ def __init__(self):
145146
btnLabels=("Sigmoid calibration", "Isotonic calibration"),
146147
label="Output model calibration", callback=self.apply)
147148

148-
box = gui.widgetBox(self.controlArea, "Info")
149-
self.info_label = gui.widgetLabel(box)
149+
self.info_box = gui.widgetBox(self.controlArea, "Info")
150+
self.info_label = gui.widgetLabel(self.info_box)
150151

151152
gui.auto_commit(
152153
self.controlArea, self, "auto_commit", "Apply", commit=self.apply)
@@ -159,6 +160,10 @@ def __init__(self):
159160
for axis_name in ("bottom", "left"):
160161
axis = self.plot.getAxis(axis_name)
161162
axis.setPen(pg.mkPen(color=0.0))
163+
# Remove the condition (that is, allow setting this for bottom
164+
# axis) when pyqtgraph is fixed
165+
# Issue: https://github.com/pyqtgraph/pyqtgraph/issues/930
166+
# Pull request: https://github.com/pyqtgraph/pyqtgraph/pull/932
162167
if axis_name != "bottom": # remove if when pyqtgraph is fixed
163168
axis.setStyle(stopAxisAtTick=(True, True))
164169

@@ -172,11 +177,14 @@ def __init__(self):
172177
def set_results(self, results):
173178
self.closeContext()
174179
self.clear()
175-
results = check_results_adequacy(results, self.Error, check_nan=False)
180+
self.Error.clear()
181+
self.Information.clear()
182+
if results is not None and not results.domain.has_discrete_class:
183+
self.Error.non_discrete_target()
184+
results = None
176185
if results is not None and not results.actual.size:
177-
self.Warning.empty_input()
178-
else:
179-
self.Warning.empty_input.clear()
186+
self.Error.empty_input()
187+
results = None
180188
self.results = results
181189
if self.results is not None:
182190
self._initialize(results)
@@ -219,8 +227,10 @@ def _set_explanation(self):
219227

220228
if self.score == 0:
221229
self.controls.output_calibration.show()
230+
self.info_box.hide()
222231
else:
223232
self.controls.output_calibration.hide()
233+
self.info_box.show()
224234

225235
axis = self.plot.getAxis("bottom")
226236
axis.setLabel("Predicted probability" if self.score == 0
@@ -230,23 +240,23 @@ def _set_explanation(self):
230240
axis.setLabel(Metrics[self.score].name)
231241

232242
def _initialize(self, results):
233-
N = len(results.predicted)
243+
n = len(results.predicted)
234244
names = getattr(results, "learner_names", None)
235245
if names is None:
236-
names = ["#{}".format(i + 1) for i in range(N)]
246+
names = ["#{}".format(i + 1) for i in range(n)]
237247

238248
self.classifier_names = names
239249
scheme = colorbrewer.colorSchemes["qualitative"]["Dark2"]
240-
if N > len(scheme):
250+
if n > len(scheme):
241251
scheme = colorpalette.DefaultRGBColors
242-
self.colors = colorpalette.ColorPaletteGenerator(N, scheme)
252+
self.colors = colorpalette.ColorPaletteGenerator(n, scheme)
243253

244-
for i in range(N):
254+
for i in range(n):
245255
item = self.classifiers_list_box.item(i)
246256
item.setIcon(colorpalette.ColorPixmap(self.colors[i]))
247257

248-
self.selected_classifiers = list(range(N))
249-
self.target_cb.addItems(results.data.domain.class_var.values)
258+
self.selected_classifiers = list(range(n))
259+
self.target_cb.addItems(results.domain.class_var.values)
250260

251261
def _rug(self, data, pen_args):
252262
color = pen_args["pen"].color()
@@ -288,7 +298,6 @@ def _prob_curve(self, ytrue, probs, pen_args):
288298
y = np.full(100, xmax)
289299

290300
self.plot.plot(x, y, symbol="+", symbolSize=4, **pen_args)
291-
self.plot.plot([0, 1], [0, 1], antialias=True)
292301
return x, (y, )
293302

294303
def _setup_plot(self):
@@ -326,6 +335,9 @@ def _setup_plot(self):
326335
self.plot_metrics(Curves(fold_ytrue, fold_probs),
327336
metrics, pen_args)
328337

338+
if self.score == 0:
339+
self.plot.plot([0, 1], [0, 1], antialias=True)
340+
329341
def _replot(self):
330342
self.plot.clear()
331343
if self.results is not None:
@@ -379,7 +391,7 @@ def _update_info(self):
379391
for curve in curves)
380392
text += "</tr>"
381393
text += "<table>"
382-
self.info_label.setText(text)
394+
self.info_label.setText(text)
383395

384396
def threshold_change_done(self):
385397
self.apply()
@@ -395,7 +407,7 @@ def apply(self):
395407
info.no_output_no_models: results.models is None,
396408
info.no_output_multiple_selected:
397409
len(self.selected_classifiers) != 1,
398-
info.non_binary_class:
410+
info.no_output_non_binary_class:
399411
self.score != 0
400412
and len(results.domain.class_var.values) != 2}
401413
if not any(problems.values()):
@@ -419,11 +431,19 @@ def apply(self):
419431
def send_report(self):
420432
if self.results is None:
421433
return
434+
self.report_items((
435+
("Target class", self.target_cb.currentText()),
436+
("Output model calibration",
437+
self.score == 0 and self.controls.score.currentText()),
438+
))
422439
caption = report.list_legend(self.classifiers_list_box,
423440
self.selected_classifiers)
424-
self.report_items((("Target class", self.target_cb.currentText()),))
425441
self.report_plot()
426442
self.report_caption(caption)
443+
self.report_caption(self.controls.score.currentText())
444+
445+
if self.score != 0:
446+
self.report_raw(self.info_label.text())
427447

428448

429449
def gaussian_smoother(x, y, sigma=1.0):

Orange/widgets/evaluate/tests/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@ def test_many_evaluation_results(self):
1717
classification.NaiveBayesLearner(),
1818
classification.SGDClassificationLearner()
1919
]
20-
res = evaluation.CrossValidation(data, learners, k=2, store_data=True)
20+
res = evaluation.CrossValidation(k=2, store_data=True)(data, learners)
2121
# this is a mixin; pylint: disable=no-member
2222
self.send_signal("Evaluation Results", res)

0 commit comments

Comments (0)