Skip to content

Commit 47308ca

Browse files
committed
Calibration plot: Test missing probabilities and single classes
1 parent 2edcb39 commit 47308ca

File tree

2 files changed

+207
-89
lines changed

2 files changed

+207
-89
lines changed

Orange/widgets/evaluate/owcalibrationplot.py

Lines changed: 96 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -74,18 +74,23 @@ class Outputs:
7474
class Error(widget.OWWidget.Error):
    # Errors that prevent the widget from showing any calibration curve.

    # Raised by set_results when the results' domain has no discrete class.
    non_discrete_target = Msg("Calibration plot requires a discrete target")
    # Raised by set_results when `results.actual` is empty.
    empty_input = widget.Msg("Empty result on input. Nothing to display.")
    # Raised by set_results when `results.actual` contains NaN.
    nan_classes = widget.Msg(
        "Remove test data instances with unknown classes")
    # The next two are raised by _check_class_presence when the actual
    # classes are all equal to, or all different from, the chosen target.
    all_target_class = widget.Msg(
        "All data instances belong to target class")
    no_target_class = widget.Msg(
        "No data instances belong to target class")
83+
84+
class Warning(widget.OWWidget.Warning):
    # Non-fatal notices: something was left out of the plot.

    # Raised in _setup_plot for a fold whose curve has p * n == 0.
    omitted_folds = widget.Msg(
        "Test folds where all data belongs to (non)-target are not shown")
    # Raised in _setup_plot when a model has fewer usable points than there
    # are rows in `results.probabilities`.
    # The trailing space after "are" is required: adjacent string literals
    # are concatenated with nothing in between.
    omitted_nan_prob_points = widget.Msg(
        "Instances for which the model couldn't compute probabilities are "
        "skipped")
    # Formatted with a comma-separated list of model names that produced no
    # usable points at all.
    no_valid_data = widget.Msg("No valid data for model(s) {}")
7791

7892
class Information(widget.OWWidget.Information):
79-
no_out = "Can't output a model: "
80-
no_output_multiple_folds = Msg(
81-
no_out + "each training data sample produces a different model")
82-
no_output_no_models = Msg(
83-
no_out + "test results do not contain stored models;\n"
84-
"try testing on separate data or on training data")
85-
no_output_multiple_selected = Msg(
86-
no_out + "select a single model - the widget can output only one")
87-
no_output_non_binary_class = Msg(
88-
no_out + "cannot calibrate non-binary classes")
93+
no_output = Msg("Can't output a model: {}")
8994

9095
settingsHandler = EvaluationResultsContextHandler()
9196
target_index = settings.ContextSetting(0)
@@ -179,19 +184,23 @@ def set_results(self, results):
179184
self.clear()
180185
self.Error.clear()
181186
self.Information.clear()
182-
if results is not None and not results.domain.has_discrete_class:
183-
self.Error.non_discrete_target()
184-
results = None
185-
if results is not None and not results.actual.size:
186-
self.Error.empty_input()
187-
results = None
188-
self.results = results
189-
if self.results is not None:
190-
self._initialize(results)
191-
class_var = self.results.domain.class_var
192-
self.target_index = int(len(class_var.values) == 2)
193-
self.openContext(class_var, self.classifier_names)
194-
self._replot()
187+
188+
self.results = None
189+
if results is not None:
190+
if not results.domain.has_discrete_class:
191+
self.Error.non_discrete_target()
192+
elif not results.actual.size:
193+
self.Error.empty_input()
194+
elif np.any(np.isnan(results.actual)):
195+
self.Error.nan_classes()
196+
else:
197+
self.results = results
198+
self._initialize(results)
199+
class_var = self.results.domain.class_var
200+
self.target_index = int(len(class_var.values) == 2)
201+
self.openContext(class_var, self.classifier_names)
202+
self._replot()
203+
195204
self.apply()
196205

197206
def clear(self):
@@ -286,9 +295,6 @@ def plot_metrics(self, data, metrics, pen_args):
286295
return data.probs, ys
287296

288297
def _prob_curve(self, ytrue, probs, pen_args):
289-
if not probs.size:
290-
return None
291-
292298
xmin, xmax = probs.min(), probs.max()
293299
x = np.linspace(xmin, xmax, 100)
294300
if xmax != xmin:
@@ -307,16 +313,25 @@ def _setup_plot(self):
307313
plot_folds = self.fold_curves and results.folds is not None
308314
self.scores = []
309315

310-
ytrue = results.actual == target
316+
if not self._check_class_presence(results.actual == target):
317+
return
318+
319+
self.Warning.omitted_folds.clear()
320+
self.Warning.omitted_nan_prob_points.clear()
321+
no_valid_models = []
322+
shadow_width = 4 + 4 * plot_folds
311323
for clsf in self.selected_classifiers:
312-
probs = results.probabilities[clsf, :, target]
324+
data = Curves.from_results(results, target, clsf)
325+
if data.tot == 0: # all probabilities are nan
326+
no_valid_models.append(clsf)
327+
continue
328+
if data.tot != results.probabilities.shape[1]: # some are nan
329+
self.Warning.omitted_nan_prob_points()
330+
313331
color = self.colors[clsf]
314332
pen_args = dict(
315-
pen=pg.mkPen(color, width=1),
316-
shadowPen=pg.mkPen(color.lighter(160),
317-
width=4 + 4 * plot_folds),
318-
antiAlias=True)
319-
data = Curves(ytrue, probs)
333+
pen=pg.mkPen(color, width=1), antiAlias=True,
334+
shadowPen=pg.mkPen(color.lighter(160), width=shadow_width))
320335
self.scores.append(
321336
(self.classifier_names[clsf],
322337
self.plot_metrics(data, metrics, pen_args)))
@@ -330,28 +345,46 @@ def _setup_plot(self):
330345
antiAlias=True)
331346
for fold in range(len(results.folds)):
332347
fold_results = results.get_fold(fold)
333-
fold_ytrue = fold_results.actual == target
334-
fold_probs = fold_results.probabilities[clsf, :, target]
335-
self.plot_metrics(Curves(fold_ytrue, fold_probs),
336-
metrics, pen_args)
348+
fold_curve = Curves.from_results(fold_results, target, clsf)
349+
# Can't check this before: p and n can be 0 because of
350+
# nan probabilities
351+
if fold_curve.p * fold_curve.n == 0:
352+
self.Warning.omitted_folds()
353+
self.plot_metrics(fold_curve, metrics, pen_args)
354+
355+
if no_valid_models:
356+
self.Warning.no_valid_data(
357+
", ".join(self.classifier_names[i] for i in no_valid_models))
337358

338359
if self.score == 0:
339360
self.plot.plot([0, 1], [0, 1], antialias=True)
340-
341-
def _replot(self):
342-
self.plot.clear()
343-
if self.results is not None:
344-
self._setup_plot()
345-
if self.score != 0:
361+
else:
346362
self.line = pg.InfiniteLine(
347363
pos=self.threshold, movable=True,
348364
pen=pg.mkPen(color="k", style=Qt.DashLine, width=2),
349365
hoverPen=pg.mkPen(color="k", style=Qt.DashLine, width=3),
350366
bounds=(0, 1),
351367
)
352368
self.line.sigPositionChanged.connect(self.threshold_change)
353-
self.line.sigPositionChangeFinished.connect(self.threshold_change_done)
369+
self.line.sigPositionChangeFinished.connect(
370+
self.threshold_change_done)
354371
self.plot.addItem(self.line)
372+
373+
def _check_class_presence(self, ytrue):
    """
    Report whether `ytrue` contains both target and non-target instances.

    Both related errors are cleared first. If the largest value of `ytrue`
    is 0, `Error.no_target_class` is shown; otherwise, if the smallest
    value is 1, `Error.all_target_class` is shown.

    Args:
        ytrue: array compared against 0 and 1; the caller passes
            `results.actual == target`.

    Returns:
        bool: `True` if neither error applies, `False` if one was shown.
    """
    errors = self.Error
    errors.all_target_class.clear()
    errors.no_target_class.clear()

    if np.max(ytrue) == 0:
        # Not a single instance of the target class.
        report = errors.no_target_class
    elif np.min(ytrue) == 1:
        # Every instance is of the target class.
        report = errors.all_target_class
    else:
        return True

    report()
    return False
383+
384+
def _replot(self):
385+
self.plot.clear()
386+
if self.results is not None:
387+
self._setup_plot()
355388
self._update_info()
356389

357390
def _on_display_rug_changed(self):
@@ -397,20 +430,28 @@ def threshold_change_done(self):
397430
self.apply()
398431

399432
def apply(self):
400-
info = self.Information
433+
self.Information.no_output.clear()
401434
wrapped = None
402-
problems = {}
403435
results = self.results
404436
if results is not None:
405-
problems = {
406-
info.no_output_multiple_folds: len(results.folds) > 1,
407-
info.no_output_no_models: results.models is None,
408-
info.no_output_multiple_selected:
409-
len(self.selected_classifiers) != 1,
410-
info.no_output_non_binary_class:
411-
self.score != 0
412-
and len(results.domain.class_var.values) != 2}
413-
if not any(problems.values()):
437+
problems = [
438+
msg for condition, msg in (
439+
(len(results.folds) > 1,
440+
"each training data sample produces a different model"),
441+
(results.models is None,
442+
"test results do not contain stored models - try testing on"
443+
"separate data or on training data"),
444+
(len(self.selected_classifiers) != 1,
445+
"select a single model - the widget can output only one"),
446+
(self.score != 0 and len(results.domain.class_var.values) != 2,
447+
"cannot calibrate non-binary classes"))
448+
if condition]
449+
if len(problems) == 1:
450+
self.Information.no_output(problems[0])
451+
elif problems:
452+
self.Information.no_output(
453+
"".join(f"\n - {problem}" for problem in problems))
454+
else:
414455
clsf_idx = self.selected_classifiers[0]
415456
model = results.models[0, clsf_idx]
416457
if self.score == 0:
@@ -424,9 +465,6 @@ def apply(self):
424465
wrapped = ThresholdClassifier(model, threshold)
425466

426467
self.Outputs.calibrated_model.send(wrapped)
427-
for info, shown in problems.items():
428-
if info.is_shown() != shown:
429-
info(shown=shown)
430468

431469
def send_report(self):
432470
if self.results is None:

Orange/widgets/evaluate/tests/test_owcalibrationplot.py

Lines changed: 111 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -21,22 +21,6 @@
2121

2222

2323
class TestOWCalibrationPlot(WidgetTest, EvaluateTest):
24-
@classmethod
25-
def setUpClass(cls):
26-
super().setUpClass()
27-
cls.lenses = data = Table(test_filename("datasets/lenses.tab"))
28-
majority = Orange.classification.MajorityLearner()
29-
majority.name = "majority"
30-
knn3 = Orange.classification.KNNLearner(n_neighbors=3)
31-
knn3.name = "knn-3"
32-
knn1 = Orange.classification.KNNLearner(n_neighbors=1)
33-
knn1.name = "knn-1"
34-
cls.lenses_results = Orange.evaluation.TestOnTestData(
35-
store_data=True, store_models=True)(
36-
data=data[::2], test_data=data[1::2],
37-
learners=[majority, knn3, knn1])
38-
cls.lenses_results.learner_names = ["majority", "knn-3", "knn-1"]
39-
4024
def setUp(self):
4125
super().setUp()
4226

@@ -56,12 +40,25 @@ def setUp(self):
5640
self.results = Results(
5741
domain=domain,
5842
actual=actual,
59-
folds=(Ellipsis, ),
43+
folds=np.array([Ellipsis]),
6044
models=np.array([[Mock(), Mock()]]),
6145
row_indices=np.arange(19),
6246
predicted=np.array((pred, pred2)),
6347
probabilities=np.array([probs, probs2]))
6448

49+
self.lenses = data = Table(test_filename("datasets/lenses.tab"))
50+
majority = Orange.classification.MajorityLearner()
51+
majority.name = "majority"
52+
knn3 = Orange.classification.KNNLearner(n_neighbors=3)
53+
knn3.name = "knn-3"
54+
knn1 = Orange.classification.KNNLearner(n_neighbors=1)
55+
knn1.name = "knn-1"
56+
self.lenses_results = Orange.evaluation.TestOnTestData(
57+
store_data=True, store_models=True)(
58+
data=data[::2], test_data=data[1::2],
59+
learners=[majority, knn3, knn1])
60+
self.lenses_results.learner_names = ["majority", "knn-3", "knn-1"]
61+
6562
self.widget = self.create_widget(OWCalibrationPlot) # type: OWCalibrationPlot
6663
warnings.filterwarnings("ignore", ".*", ConvergenceWarning)
6764

@@ -389,24 +386,31 @@ def test_apply_no_output(self, *_):
389386
widget = self.widget
390387
model_list = widget.controls.selected_classifiers
391388

392-
info = widget.Information
393-
infos = (info.no_output_multiple_folds,
394-
info.no_output_no_models,
395-
info.no_output_multiple_selected,
396-
info.no_output_non_binary_class)
397-
multiple_folds, no_models, multiple_selected, non_binary_class = infos
389+
multiple_folds, multiple_selected, no_models, non_binary_class = "abcd"
390+
messages = {
391+
multiple_folds:
392+
"each training data sample produces a different model",
393+
no_models:
394+
"test results do not contain stored models - try testing on"
395+
"separate data or on training data",
396+
multiple_selected:
397+
"select a single model - the widget can output only one",
398+
non_binary_class:
399+
"cannot calibrate non-binary classes"}
398400

399401
def test_shown(shown):
400-
for info in infos:
401-
self.assertEqual(
402-
info.is_shown(), info in shown,
403-
f"{info} is unexpectedly "
404-
f"{'' if info.is_shown() else 'not'} shown")
402+
widget_msg = widget.Information.no_output
405403
output = self.get_output(widget.Outputs.calibrated_model)
406-
if shown:
407-
self.assertIsNone(output)
408-
else:
404+
if not shown:
405+
self.assertFalse(widget_msg.is_shown())
409406
self.assertIsNotNone(output)
407+
else:
408+
self.assertTrue(widget_msg.is_shown())
409+
self.assertIsNone(output)
410+
for msg_id in shown:
411+
msg = messages[msg_id]
412+
self.assertIn(msg, widget_msg.formatted,
413+
f"{msg} not included in the message")
410414

411415
self.send_signal(widget.Inputs.evaluation_results, self.results)
412416
self._set_combo(widget.controls.score, 1) # CA
@@ -558,3 +562,79 @@ def test_report(self):
558562
widget = self.widget
559563
self.send_signal(widget.Inputs.evaluation_results, self.lenses_results)
560564
widget.send_report()
565+
566+
@patch("Orange.widgets.evaluate.owcalibrationplot.ThresholdClassifier")
567+
@patch("Orange.widgets.evaluate.owcalibrationplot.CalibratedLearner")
568+
def test_single_class(self, *_):
569+
"""Curves are not plotted if all data belongs to (non)-target"""
570+
def check_error(shown):
571+
for error in (errors.no_target_class, errors.all_target_class,
572+
errors.nan_classes):
573+
self.assertEqual(error.is_shown(), error is shown,
574+
f"{error} is unexpectedly"
575+
f"{'' if error.is_shown() else ' not'} shown")
576+
if shown is not None:
577+
self.assertEqual(len(widget.plot.items), 0)
578+
else:
579+
self.assertGreater(len(widget.plot.items), 0)
580+
581+
widget = self.widget
582+
errors = widget.Error
583+
widget.display_rug = True
584+
combo = widget.controls.score
585+
586+
original_actual = self.results.actual.copy()
587+
self.send_signal(widget.Inputs.evaluation_results, self.results)
588+
widget.selected_classifiers = [0]
589+
for idx in range(combo.count()):
590+
self._set_combo(combo, idx)
591+
self.results.actual[:] = 0
592+
self.send_signal(widget.Inputs.evaluation_results, self.results)
593+
check_error(errors.no_target_class)
594+
595+
self.results.actual[:] = 1
596+
self.send_signal(widget.Inputs.evaluation_results, self.results)
597+
check_error(errors.all_target_class)
598+
599+
self.results.actual[:] = original_actual
600+
self.results.actual[3] = np.nan
601+
self.send_signal(widget.Inputs.evaluation_results, self.results)
602+
check_error(errors.nan_classes)
603+
604+
self.results.actual[:] = original_actual
605+
self.send_signal(widget.Inputs.evaluation_results, self.results)
606+
check_error(None)
607+
608+
@patch("Orange.widgets.evaluate.owcalibrationplot.ThresholdClassifier")
@patch("Orange.widgets.evaluate.owcalibrationplot.CalibratedLearner")
def test_single_class_folds(self, *_):
    """Enabling fold curves warns about folds containing a single class"""
    w = self.widget
    w.display_rug = False
    w.fold_curves = False

    # Turn the lenses results into two folds, rows 0-4 and rows 5-18,
    # repeating the stored models once per fold.
    res = self.lenses_results
    res.folds = [slice(0, 5), slice(5, 19)]
    res.models = res.models.repeat(2, axis=0)

    res.actual[:3] = 0
    res.probabilities[1, 3:5] = np.nan
    # after this, model 1 has just negative instances in fold 0

    self.send_signal(w.Inputs.evaluation_results, res)
    self._set_combo(w.controls.score, 1)  # CA

    # With fold curves disabled there is no warning; ticking the box shows it.
    self.assertFalse(w.Warning.omitted_folds.is_shown())
    w.controls.fold_curves.click()
    self.assertTrue(w.Warning.omitted_folds.is_shown())
627+
628+
@patch("Orange.widgets.evaluate.owcalibrationplot.ThresholdClassifier")
629+
@patch("Orange.widgets.evaluate.owcalibrationplot.CalibratedLearner")
630+
def test_warn_nan_probabilities(self, *_):
631+
"""Warn about omitted points with nan probabilities"""
632+
widget = self.widget
633+
widget.display_rug = False
634+
widget.fold_curves = False
635+
636+
self.results.probabilities[1, 3] = np.nan
637+
self.send_signal(widget.Inputs.evaluation_results, self.results)
638+
self.assertTrue(widget.Warning.omitted_nan_prob_points.is_shown())
639+
self._set_list_selection(widget.controls.selected_classifiers, [0, 2])
640+
self.assertFalse(widget.Warning.omitted_folds.is_shown())

0 commit comments

Comments
 (0)