Skip to content

Commit 26463de

Browse files
committed
Prediction: Output error
1 parent 480930e commit 26463de

File tree

2 files changed

+65
-10
lines changed

2 files changed

+65
-10
lines changed

Orange/widgets/evaluate/owpredictions.py

Lines changed: 25 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -265,6 +265,12 @@ def class_var(self):
265265
def is_discrete_class(self):
266266
return bool(self.class_var) and self.class_var.is_discrete
267267

268+
@property
269+
def shown_errors(self):
270+
return self.class_var and (
271+
self.show_probability_errors if self.is_discrete_class
272+
else self.show_reg_errors != NO_ERR)
273+
268274
@Inputs.predictors
269275
def set_predictor(self, index, predictor: Model):
270276
item = self.predictors[index]
@@ -331,15 +337,14 @@ def _reg_error_changed(self):
331337
self._update_prediction_delegate()
332338

333339
def _update_errors_visibility(self):
334-
shown = self.class_var and (
335-
self.show_probability_errors if self.is_discrete_class
336-
else self.show_reg_errors != NO_ERR)
340+
shown = self.shown_errors
337341
view = self.predictionsview
338342
for col, slot in enumerate(self.predictors):
339343
view.setColumnHidden(
340344
2 * col + 1,
341345
not shown or
342346
self.is_discrete_class is not slot.predictor.domain.has_discrete_class)
347+
self._commit_predictions()
343348

344349
def _set_class_values(self):
345350
self.class_values = []
@@ -814,12 +819,12 @@ def _commit_predictions(self):
814819

815820
newmetas = []
816821
newcolumns = []
817-
for slot in self._non_errored_predictors():
822+
for i, slot in enumerate(self._non_errored_predictors()):
818823
target = slot.predictor.domain.class_var
819824
if target and target.is_discrete:
820-
self._add_classification_out_columns(slot, newmetas, newcolumns)
825+
self._add_classification_out_columns(slot, newmetas, newcolumns, i)
821826
else:
822-
self._add_regression_out_columns(slot, newmetas, newcolumns)
827+
self._add_regression_out_columns(slot, newmetas, newcolumns, i)
823828

824829
attrs = list(self.data.domain.attributes)
825830
metas = list(self.data.domain.metas)
@@ -857,7 +862,7 @@ def _commit_predictions(self):
857862
predictions = predictions[datamodel.mapToSourceRows(...)]
858863
self.Outputs.predictions.send(predictions)
859864

860-
def _add_classification_out_columns(self, slot, newmetas, newcolumns):
865+
def _add_classification_out_columns(self, slot, newmetas, newcolumns, index):
861866
pred = slot.predictor
862867
name = pred.name
863868
values = pred.domain.class_var.values
@@ -877,10 +882,21 @@ def _add_classification_out_columns(self, slot, newmetas, newcolumns):
877882
else:
878883
newcolumns.append(numpy.zeros(probs.shape[0]))
879884

880-
@staticmethod
881-
def _add_regression_out_columns(slot, newmetas, newcolumns):
885+
# Column with error
886+
self._add_error_out_columns(slot, newmetas, newcolumns, index)
887+
888+
def _add_regression_out_columns(self, slot, newmetas, newcolumns, index):
882889
newmetas.append(ContinuousVariable(name=slot.predictor.name))
883890
newcolumns.append(slot.results.unmapped_predicted)
891+
self._add_error_out_columns(slot, newmetas, newcolumns, index)
892+
893+
def _add_error_out_columns(self, slot, newmetas, newcolumns, index):
894+
if self.shown_errors:
895+
name = f"{slot.predictor.name} (error)"
896+
newmetas.append(ContinuousVariable(name=name))
897+
err = self.predictionsview.model().errorColumn(index)
898+
err[err == 2] = numpy.nan
899+
newcolumns.append(err)
884900

885901
def send_report(self):
886902
def merge_data_with_predictions():

Orange/widgets/evaluate/tests/test_owpredictions.py

Lines changed: 40 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -869,6 +869,7 @@ def test_output_wrt_shown_probs_1(self):
869869
self.send_signal(widget.Inputs.predictors, bayes01, 0)
870870
self.send_signal(widget.Inputs.predictors, bayes12, 1)
871871
self.send_signal(widget.Inputs.predictors, bayes012, 2)
872+
widget.controls.show_probability_errors.setChecked(False)
872873

873874
for i, pred in enumerate(widget.predictors):
874875
p = pred.results.unmapped_probabilities
@@ -918,6 +919,7 @@ def test_output_wrt_shown_probs_2(self):
918919
self.send_signal(widget.Inputs.data, iris012)
919920
self.send_signal(widget.Inputs.predictors, bayes01, 0)
920921
self.send_signal(widget.Inputs.predictors, bayes012, 1)
922+
widget.controls.show_probability_errors.setChecked(False)
921923

922924
for i, pred in enumerate(widget.predictors):
923925
p = pred.results.unmapped_probabilities
@@ -968,7 +970,7 @@ def test_output_regression(self):
968970
MeanLearner()(self.housing), 1)
969971
out = self.get_output(widget.Outputs.predictions)
970972
np.testing.assert_equal(
971-
out.metas,
973+
out.metas[:, [0, 2]],
972974
np.hstack([pred.results.predicted.T for pred in widget.predictors]))
973975

974976
def test_classless(self):
@@ -1188,6 +1190,43 @@ def test_migrate_shown_scores(self):
11881190
self.widget.migrate_settings(settings, 1)
11891191
self.assertTrue(settings["score_table"]["show_score_hints"]["Sensitivity"])
11901192

1193+
def test_output_error_reg(self):
1194+
data = self.housing
1195+
lin_reg = LinearRegressionLearner()
1196+
self.send_signal(self.widget.Inputs.data, data)
1197+
self.send_signal(self.widget.Inputs.predictors, lin_reg(data), 0)
1198+
self.send_signal(self.widget.Inputs.predictors,
1199+
LinearRegressionLearner(fit_intercept=False)(data), 1)
1200+
pred = self.get_output(self.widget.Outputs.predictions)
1201+
1202+
names = ["", " (error)"]
1203+
names = [f"{n}{i}" for i in ("", " (1)") for n in names]
1204+
names = [f"{lin_reg.name}{x}" for x in names]
1205+
self.assertEqual(names, [m.name for m in pred.domain.metas])
1206+
self.assertAlmostEqual(pred.metas[0, 1], 6.0, 1)
1207+
self.assertAlmostEqual(pred.metas[0, 3], 5.1, 1)
1208+
1209+
def test_output_error_cls(self):
1210+
data = self.iris
1211+
log_reg = LogisticRegressionLearner()
1212+
self.send_signal(self.widget.Inputs.predictors, log_reg(data), 0)
1213+
self.send_signal(self.widget.Inputs.predictors,
1214+
LogisticRegressionLearner(penalty="l1")(data), 1)
1215+
with data.unlocked(data.Y):
1216+
data.Y[1] = np.nan
1217+
self.send_signal(self.widget.Inputs.data, data)
1218+
pred = self.get_output(self.widget.Outputs.predictions)
1219+
1220+
names = [""] + [f" ({v})" for v in
1221+
list(data.domain.class_var.values) + ["error"]]
1222+
names = [f"{n}{i}" for i in ("", " (1)") for n in names]
1223+
names = [f"{log_reg.name}{x}" for x in names]
1224+
self.assertEqual(names, [m.name for m in pred.domain.metas])
1225+
self.assertAlmostEqual(pred.metas[0, 4], 0.018, 3)
1226+
self.assertAlmostEqual(pred.metas[0, 9], 0.113, 3)
1227+
self.assertTrue(np.isnan(pred.metas[1, 4]))
1228+
self.assertTrue(np.isnan(pred.metas[1, 9]))
1229+
11911230

11921231
class SelectionModelTest(unittest.TestCase):
11931232
def setUp(self):

0 commit comments

Comments
 (0)