Skip to content

Commit 2049afa

Browse files
committed
Calibration plot: Test missing probabilities and single classes
1 parent 2edcb39 commit 2049afa

File tree

2 files changed

+208
-93
lines changed

2 files changed

+208
-93
lines changed

Orange/widgets/evaluate/owcalibrationplot.py

Lines changed: 97 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,23 @@ class Outputs:
7474
class Error(widget.OWWidget.Error):
7575
non_discrete_target = Msg("Calibration plot requires a discrete target")
7676
empty_input = widget.Msg("Empty result on input. Nothing to display.")
77+
nan_classes = \
78+
widget.Msg("Remove test data instances with unknown classes")
79+
all_target_class = widget.Msg(
80+
"All data instances belong to target class")
81+
no_target_class = widget.Msg(
82+
"No data instances belong to target class")
83+
84+
class Warning(widget.OWWidget.Warning):
85+
omitted_folds = widget.Msg(
86+
"Test folds where all data belongs to (non)-target are not shown")
87+
omitted_nan_prob_points = widget.Msg(
88+
"Instances for which the model couldn't compute probabilities are "
89+
"skipped")
90+
no_valid_data = widget.Msg("No valid data for model(s) {}")
7791

7892
class Information(widget.OWWidget.Information):
79-
no_out = "Can't output a model: "
80-
no_output_multiple_folds = Msg(
81-
no_out + "each training data sample produces a different model")
82-
no_output_no_models = Msg(
83-
no_out + "test results do not contain stored models;\n"
84-
"try testing on separate data or on training data")
85-
no_output_multiple_selected = Msg(
86-
no_out + "select a single model - the widget can output only one")
87-
no_output_non_binary_class = Msg(
88-
no_out + "cannot calibrate non-binary classes")
93+
no_output = Msg("Can't output a model: {}")
8994

9095
settingsHandler = EvaluationResultsContextHandler()
9196
target_index = settings.ContextSetting(0)
@@ -179,19 +184,23 @@ def set_results(self, results):
179184
self.clear()
180185
self.Error.clear()
181186
self.Information.clear()
182-
if results is not None and not results.domain.has_discrete_class:
183-
self.Error.non_discrete_target()
184-
results = None
185-
if results is not None and not results.actual.size:
186-
self.Error.empty_input()
187-
results = None
188-
self.results = results
189-
if self.results is not None:
190-
self._initialize(results)
191-
class_var = self.results.domain.class_var
192-
self.target_index = int(len(class_var.values) == 2)
193-
self.openContext(class_var, self.classifier_names)
194-
self._replot()
187+
188+
self.results = None
189+
if results is not None:
190+
if not results.domain.has_discrete_class:
191+
self.Error.non_discrete_target()
192+
elif not results.actual.size:
193+
self.Error.empty_input()
194+
elif np.any(np.isnan(results.actual)):
195+
self.Error.nan_classes()
196+
else:
197+
self.results = results
198+
self._initialize(results)
199+
class_var = self.results.domain.class_var
200+
self.target_index = int(len(class_var.values) == 2)
201+
self.openContext(class_var, self.classifier_names)
202+
self._replot()
203+
195204
self.apply()
196205

197206
def clear(self):
@@ -286,9 +295,6 @@ def plot_metrics(self, data, metrics, pen_args):
286295
return data.probs, ys
287296

288297
def _prob_curve(self, ytrue, probs, pen_args):
289-
if not probs.size:
290-
return None
291-
292298
xmin, xmax = probs.min(), probs.max()
293299
x = np.linspace(xmin, xmax, 100)
294300
if xmax != xmin:
@@ -307,16 +313,25 @@ def _setup_plot(self):
307313
plot_folds = self.fold_curves and results.folds is not None
308314
self.scores = []
309315

310-
ytrue = results.actual == target
316+
if not self._check_class_presence(results.actual == target):
317+
return
318+
319+
self.Warning.omitted_folds.clear()
320+
self.Warning.omitted_nan_prob_points.clear()
321+
no_valid_models = []
322+
shadow_width = 4 + 4 * plot_folds
311323
for clsf in self.selected_classifiers:
312-
probs = results.probabilities[clsf, :, target]
324+
data = Curves.from_results(results, target, clsf)
325+
if data.tot == 0: # all probabilities are nan
326+
no_valid_models.append(clsf)
327+
continue
328+
if data.tot != results.probabilities.shape[1]: # some are nan
329+
self.Warning.omitted_nan_prob_points()
330+
313331
color = self.colors[clsf]
314332
pen_args = dict(
315-
pen=pg.mkPen(color, width=1),
316-
shadowPen=pg.mkPen(color.lighter(160),
317-
width=4 + 4 * plot_folds),
318-
antiAlias=True)
319-
data = Curves(ytrue, probs)
333+
pen=pg.mkPen(color, width=1), antiAlias=True,
334+
shadowPen=pg.mkPen(color.lighter(160), width=shadow_width))
320335
self.scores.append(
321336
(self.classifier_names[clsf],
322337
self.plot_metrics(data, metrics, pen_args)))
@@ -330,28 +345,46 @@ def _setup_plot(self):
330345
antiAlias=True)
331346
for fold in range(len(results.folds)):
332347
fold_results = results.get_fold(fold)
333-
fold_ytrue = fold_results.actual == target
334-
fold_probs = fold_results.probabilities[clsf, :, target]
335-
self.plot_metrics(Curves(fold_ytrue, fold_probs),
336-
metrics, pen_args)
348+
fold_curve = Curves.from_results(fold_results, target, clsf)
349+
# Can't check this before: p and n can be 0 because of
350+
# nan probabilities
351+
if fold_curve.p * fold_curve.n == 0:
352+
self.Warning.omitted_folds()
353+
self.plot_metrics(fold_curve, metrics, pen_args)
354+
355+
if no_valid_models:
356+
self.Warning.no_valid_data(
357+
", ".join(self.classifier_names[i] for i in no_valid_models))
337358

338359
if self.score == 0:
339360
self.plot.plot([0, 1], [0, 1], antialias=True)
340-
341-
def _replot(self):
342-
self.plot.clear()
343-
if self.results is not None:
344-
self._setup_plot()
345-
if self.score != 0:
361+
else:
346362
self.line = pg.InfiniteLine(
347363
pos=self.threshold, movable=True,
348364
pen=pg.mkPen(color="k", style=Qt.DashLine, width=2),
349365
hoverPen=pg.mkPen(color="k", style=Qt.DashLine, width=3),
350366
bounds=(0, 1),
351367
)
352368
self.line.sigPositionChanged.connect(self.threshold_change)
353-
self.line.sigPositionChangeFinished.connect(self.threshold_change_done)
369+
self.line.sigPositionChangeFinished.connect(
370+
self.threshold_change_done)
354371
self.plot.addItem(self.line)
372+
373+
def _check_class_presence(self, ytrue):
374+
self.Error.all_target_class.clear()
375+
self.Error.no_target_class.clear()
376+
if np.max(ytrue) == 0:
377+
self.Error.no_target_class()
378+
return False
379+
if np.min(ytrue) == 1:
380+
self.Error.all_target_class()
381+
return False
382+
return True
383+
384+
def _replot(self):
385+
self.plot.clear()
386+
if self.results is not None:
387+
self._setup_plot()
355388
self._update_info()
356389

357390
def _on_display_rug_changed(self):
@@ -380,10 +413,7 @@ def _update_info(self):
380413
{"<td></td>".join(f"<td align='right'>{n}</td>"
381414
for n in short_names)}
382415
</tr>"""
383-
for name, probs_curves in self.scores:
384-
if probs_curves is None:
385-
continue
386-
probs, curves = probs_curves
416+
for name, (probs, curves) in self.scores:
387417
ind = min(np.searchsorted(probs, self.threshold),
388418
len(probs) - 1)
389419
text += f"<tr><th align='right'>{name}:</th>"
@@ -397,20 +427,28 @@ def threshold_change_done(self):
397427
self.apply()
398428

399429
def apply(self):
400-
info = self.Information
430+
self.Information.no_output.clear()
401431
wrapped = None
402-
problems = {}
403432
results = self.results
404433
if results is not None:
405-
problems = {
406-
info.no_output_multiple_folds: len(results.folds) > 1,
407-
info.no_output_no_models: results.models is None,
408-
info.no_output_multiple_selected:
409-
len(self.selected_classifiers) != 1,
410-
info.no_output_non_binary_class:
411-
self.score != 0
412-
and len(results.domain.class_var.values) != 2}
413-
if not any(problems.values()):
434+
problems = [
435+
msg for condition, msg in (
436+
(len(results.folds) > 1,
437+
"each training data sample produces a different model"),
438+
(results.models is None,
439+
"test results do not contain stored models - try testing on "
440+
"separate data or on training data"),
441+
(len(self.selected_classifiers) != 1,
442+
"select a single model - the widget can output only one"),
443+
(self.score != 0 and len(results.domain.class_var.values) != 2,
444+
"cannot calibrate non-binary classes"))
445+
if condition]
446+
if len(problems) == 1:
447+
self.Information.no_output(problems[0])
448+
elif problems:
449+
self.Information.no_output(
450+
"".join(f"\n - {problem}" for problem in problems))
451+
else:
414452
clsf_idx = self.selected_classifiers[0]
415453
model = results.models[0, clsf_idx]
416454
if self.score == 0:
@@ -424,9 +462,6 @@ def apply(self):
424462
wrapped = ThresholdClassifier(model, threshold)
425463

426464
self.Outputs.calibrated_model.send(wrapped)
427-
for info, shown in problems.items():
428-
if info.is_shown() != shown:
429-
info(shown=shown)
430465

431466
def send_report(self):
432467
if self.results is None:

0 commit comments

Comments (0)