Skip to content

Commit 88f1048

Browse files
authored
Predictions: Output annotated table (#6718)
* Predictions: Output annotated table * Predictions: Output annotated table
1 parent 7bdc8e4 commit 88f1048

File tree

3 files changed

+136
-33
lines changed

3 files changed

+136
-33
lines changed

Orange/widgets/evaluate/owpredictions.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from orangecanvas.utils.localization import pl
2020
from orangewidget.utils.itemmodels import AbstractSortTableModel
21+
from orangewidget.utils.signals import LazyValue
2122

2223
import Orange
2324
from Orange.evaluation import Results
@@ -31,6 +32,8 @@
3132
from Orange.widgets.utils.widgetpreview import WidgetPreview
3233
from Orange.widgets.widget import OWWidget, Msg, Input, Output, MultiInput
3334
from Orange.widgets.utils.itemmodels import TableModel
35+
from Orange.widgets.utils.annotated_data import lazy_annotated_table, \
36+
domain_with_annotation_column, create_annotated_table
3437
from Orange.widgets.utils.sql import check_sql_input
3538
from Orange.widgets.utils.state_summary import format_summary_details
3639
from Orange.widgets.utils.colorpalettes import LimitedDiscretePalette
@@ -72,7 +75,9 @@ class Inputs:
7275
predictors = MultiInput("Predictors", Model, filter_none=True)
7376

7477
class Outputs:
75-
predictions = Output("Predictions", Orange.data.Table)
78+
selected_predictions = Output("Selected Predictions", Orange.data.Table,
79+
default=True, replaces=["Predictions"])
80+
annotated = Output("Predictions", Orange.data.Table)
7681
evaluation_results = Output("Evaluation Results", Results)
7782

7883
class Warning(OWWidget.Warning):
@@ -814,7 +819,8 @@ def _commit_evaluation_results(self):
814819

815820
def _commit_predictions(self):
816821
if not self.data:
817-
self.Outputs.predictions.send(None)
822+
self.Outputs.selected_predictions.send(None)
823+
self.Outputs.annotated.send(None)
818824
return
819825

820826
newmetas = []
@@ -850,17 +856,26 @@ def _commit_predictions(self):
850856
predmodel = self.predictionsview.model()
851857
assert datamodel is not None # because we have data
852858
assert self.selection_store is not None
853-
rows = numpy.array(list(self.selection_store.rows))
859+
rows = numpy.array(list(self.selection_store.rows), dtype=int)
854860
if rows.size:
861+
domain, _ = domain_with_annotation_column(predictions)
862+
annotated_data = LazyValue[Orange.data.Table](
863+
lambda: create_annotated_table(
864+
predictions, rows)[datamodel.mapToSourceRows(...)],
865+
length=len(predictions), domain=domain)
866+
855867
# Reorder rows as they are ordered in view
856868
shown_rows = datamodel.mapFromSourceRows(rows)
857869
rows = rows[numpy.argsort(shown_rows)]
858-
predictions = predictions[rows]
859-
elif datamodel.sortColumn() >= 0 \
860-
or predmodel is not None and predmodel.sortColumn() > 0:
861-
# No selection: output all, but in the shown order
862-
predictions = predictions[datamodel.mapToSourceRows(...)]
863-
self.Outputs.predictions.send(predictions)
870+
selected = predictions[rows]
871+
else:
872+
if datamodel.sortColumn() >= 0 \
873+
or predmodel is not None and predmodel.sortColumn() > 0:
874+
predictions = predictions[datamodel.mapToSourceRows(...)]
875+
selected = predictions
876+
annotated_data = lazy_annotated_table(predictions, rows)
877+
self.Outputs.selected_predictions.send(selected)
878+
self.Outputs.annotated.send(annotated_data)
864879

865880
def _add_classification_out_columns(self, slot, newmetas, newcolumns, index):
866881
pred = slot.predictor

Orange/widgets/evaluate/tests/test_owpredictions.py

Lines changed: 111 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from Orange.evaluation import Results
3737
from Orange.widgets.tests.utils import excepthook_catch, \
3838
possible_duplicate_table, simulate
39+
from Orange.widgets.utils.annotated_data import ANNOTATED_DATA_FEATURE_NAME
3940
from Orange.widgets.utils.colorpalettes import LimitedDiscretePalette
4041

4142

@@ -62,7 +63,7 @@ def test_nan_target_input(self):
6263
yvec = data.get_column(data.domain.class_var)
6364
self.send_signal(self.widget.Inputs.data, data)
6465
self.send_signal(self.widget.Inputs.predictors, ConstantLearner()(data), 1)
65-
pred = self.get_output(self.widget.Outputs.predictions)
66+
pred = self.get_output(self.widget.Outputs.selected_predictions)
6667
self.assertIsInstance(pred, Table)
6768
np.testing.assert_array_equal(
6869
yvec, pred.get_column(data.domain.class_var))
@@ -92,7 +93,7 @@ def test_no_values_target(self):
9293
test = Table(domain, np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]]),
9394
np.full((3, 1), np.nan))
9495
self.send_signal(self.widget.Inputs.data, test)
95-
pred = self.get_output(self.widget.Outputs.predictions)
96+
pred = self.get_output(self.widget.Outputs.selected_predictions)
9697
self.assertEqual(len(pred), len(test))
9798

9899
results = self.get_output(self.widget.Outputs.evaluation_results)
@@ -145,7 +146,7 @@ def test_no_class_on_test(self):
145146
no_class = titanic.transform(Domain(titanic.domain.attributes, None))
146147
self.send_signal(self.widget.Inputs.predictors, majority_titanic, 1)
147148
self.send_signal(self.widget.Inputs.data, no_class)
148-
out = self.get_output(self.widget.Outputs.predictions)
149+
out = self.get_output(self.widget.Outputs.selected_predictions)
149150
np.testing.assert_allclose(out.get_column("constant"), 0)
150151

151152
predmodel = self.widget.predictionsview.model()
@@ -500,7 +501,7 @@ def test_unique_output_domain(self):
500501
self.send_signal(self.widget.Inputs.data, data)
501502
self.send_signal(self.widget.Inputs.predictors, predictor)
502503

503-
output = self.get_output(self.widget.Outputs.predictions)
504+
output = self.get_output(self.widget.Outputs.selected_predictions)
504505
self.assertEqual(output.domain.metas[0].name, 'constant (1)')
505506

506507
def test_select(self):
@@ -515,6 +516,92 @@ def test_select(self):
515516
for index in self.widget.dataview.selectionModel().selectedIndexes()}
516517
self.assertEqual(sel, {(1, col) for col in range(5)})
517518

519+
def test_selection_output(self):
520+
log_reg_iris = LogisticRegressionLearner()(self.iris)
521+
self.send_signal(self.widget.Inputs.predictors, log_reg_iris)
522+
self.send_signal(self.widget.Inputs.data, self.iris)
523+
524+
selmodel = self.widget.dataview.selectionModel()
525+
pred_model = self.widget.predictionsview.model()
526+
527+
selmodel.select(self.widget.dataview.model().index(1, 0), QItemSelectionModel.Select)
528+
selmodel.select(self.widget.dataview.model().index(3, 0), QItemSelectionModel.Select)
529+
output = self.get_output(self.widget.Outputs.selected_predictions)
530+
self.assertEqual(len(output), 2)
531+
self.assertEqual(output[0], self.iris[1])
532+
self.assertEqual(output[1], self.iris[3])
533+
output = self.get_output(self.widget.Outputs.annotated)
534+
self.assertEqual(len(output), len(self.iris))
535+
col = output.get_column(ANNOTATED_DATA_FEATURE_NAME)
536+
self.assertEqual(np.sum(col), 2)
537+
self.assertEqual(col[1], 1)
538+
self.assertEqual(col[3], 1)
539+
540+
pred_model.sort(0)
541+
output = self.get_output(self.widget.Outputs.selected_predictions)
542+
self.assertEqual(len(output), 2)
543+
self.assertEqual(output[0], self.iris[1])
544+
self.assertEqual(output[1], self.iris[3])
545+
ann_output = self.get_output(self.widget.Outputs.annotated)
546+
self.assertEqual(len(ann_output), len(self.iris))
547+
col = ann_output.get_column(ANNOTATED_DATA_FEATURE_NAME)
548+
self.assertEqual(np.sum(col), 2)
549+
np.testing.assert_array_equal(ann_output[col == 1].X, output.X)
550+
551+
pred_model.sort(0, Qt.DescendingOrder)
552+
output = self.get_output(self.widget.Outputs.selected_predictions)
553+
self.assertEqual(len(output), 2)
554+
self.assertEqual(output[0], self.iris[3])
555+
self.assertEqual(output[1], self.iris[1])
556+
ann_output = self.get_output(self.widget.Outputs.annotated)
557+
self.assertEqual(len(ann_output), len(self.iris))
558+
col = ann_output.get_column(ANNOTATED_DATA_FEATURE_NAME)
559+
self.assertEqual(np.sum(col), 2)
560+
np.testing.assert_array_equal(ann_output[col == 1].X, output.X)
561+
562+
def test_no_selection_output(self):
563+
log_reg_iris = LogisticRegressionLearner()(self.iris)
564+
self.send_signal(self.widget.Inputs.predictors, log_reg_iris)
565+
self.send_signal(self.widget.Inputs.data, self.iris)
566+
567+
data_model = self.widget.dataview.model()
568+
569+
output = self.get_output(self.widget.Outputs.selected_predictions)
570+
self.assertEqual(len(output), len(self.iris))
571+
output = self.get_output(self.widget.Outputs.annotated)
572+
self.assertEqual(len(output), len(self.iris))
573+
col = output.get_column(ANNOTATED_DATA_FEATURE_NAME)
574+
self.assertFalse(np.any(col))
575+
576+
data_model.sort(2)
577+
col_name = data_model.headerData(2, Qt.Horizontal, Qt.DisplayRole) # "sepal width"
578+
output = self.get_output(self.widget.Outputs.selected_predictions)
579+
self.assertEqual(len(output), len(self.iris))
580+
col = output.get_column(col_name)
581+
self.assertTrue(np.all(col[1:] >= col[:-1]))
582+
583+
output = self.get_output(self.widget.Outputs.annotated)
584+
self.assertEqual(len(output), len(self.iris))
585+
col = output.get_column(col_name)
586+
self.assertTrue(np.all(col[1:] >= col[:-1]))
587+
col = output.get_column(ANNOTATED_DATA_FEATURE_NAME)
588+
self.assertFalse(np.any(col))
589+
590+
data_model.sort(2, Qt.DescendingOrder)
591+
col_name = data_model.headerData(2, Qt.Horizontal, Qt.DisplayRole) # "sepal width"
592+
output = self.get_output(self.widget.Outputs.selected_predictions)
593+
self.assertEqual(len(output), len(self.iris))
594+
col = output.get_column(col_name)
595+
self.assertTrue(np.all(col[1:] <= col[:-1]))
596+
597+
output = self.get_output(self.widget.Outputs.annotated)
598+
self.assertEqual(len(output), len(self.iris))
599+
col = output.get_column(col_name)
600+
self.assertTrue(np.all(col[1:] <= col[:-1]))
601+
col = output.get_column(ANNOTATED_DATA_FEATURE_NAME)
602+
self.assertFalse(np.any(col))
603+
604+
518605
def test_select_data_first(self):
519606
log_reg_iris = LogisticRegressionLearner()(self.iris)
520607
self.send_signal(self.widget.Inputs.data, self.iris)
@@ -537,7 +624,7 @@ def test_selection_in_setting(self):
537624
for index in widget.dataview.selectionModel().selectedIndexes()}
538625
self.assertEqual(sel, {(row, col)
539626
for row in [1, 3, 4] for col in range(5)})
540-
out = self.get_output(widget.Outputs.predictions)
627+
out = self.get_output(widget.Outputs.selected_predictions)
541628
exp = self.iris[np.array([1, 3, 4])]
542629
np.testing.assert_equal(out.X, exp.X)
543630

@@ -883,32 +970,32 @@ def test_output_wrt_shown_probs_1(self):
883970

884971
widget.shown_probs = widget.NO_PROBS
885972
widget._commit_predictions()
886-
out = self.get_output(widget.Outputs.predictions)
973+
out = self.get_output(widget.Outputs.selected_predictions)
887974
self.assertEqual(list(out.metas[0]), [0, 1, 2])
888975

889976
widget.shown_probs = widget.DATA_PROBS
890977
widget._commit_predictions()
891-
out = self.get_output(widget.Outputs.predictions)
978+
out = self.get_output(widget.Outputs.selected_predictions)
892979
self.assertEqual(list(out.metas[0]), [0, 10, 11, 1, 0, 110, 2, 210, 211])
893980

894981
widget.shown_probs = widget.MODEL_PROBS
895982
widget._commit_predictions()
896-
out = self.get_output(widget.Outputs.predictions)
983+
out = self.get_output(widget.Outputs.selected_predictions)
897984
self.assertEqual(list(out.metas[0]), [0, 10, 11, 1, 110, 111, 2, 210, 211, 212])
898985

899986
widget.shown_probs = widget.BOTH_PROBS
900987
widget._commit_predictions()
901-
out = self.get_output(widget.Outputs.predictions)
988+
out = self.get_output(widget.Outputs.selected_predictions)
902989
self.assertEqual(list(out.metas[0]), [0, 10, 11, 1, 110, 2, 210, 211])
903990

904991
widget.shown_probs = widget.BOTH_PROBS + 1
905992
widget._commit_predictions()
906-
out = self.get_output(widget.Outputs.predictions)
993+
out = self.get_output(widget.Outputs.selected_predictions)
907994
self.assertEqual(list(out.metas[0]), [0, 10, 1, 0, 2, 210])
908995

909996
widget.shown_probs = widget.BOTH_PROBS + 2
910997
widget._commit_predictions()
911-
out = self.get_output(widget.Outputs.predictions)
998+
out = self.get_output(widget.Outputs.selected_predictions)
912999
self.assertEqual(list(out.metas[0]), [0, 11, 1, 110, 2, 211])
9131000

9141001
def test_output_wrt_shown_probs_2(self):
@@ -933,37 +1020,37 @@ def test_output_wrt_shown_probs_2(self):
9331020

9341021
widget.shown_probs = widget.NO_PROBS
9351022
widget._commit_predictions()
936-
out = self.get_output(widget.Outputs.predictions)
1023+
out = self.get_output(widget.Outputs.selected_predictions)
9371024
self.assertEqual(list(out.metas[0]), [0, 1])
9381025

9391026
widget.shown_probs = widget.DATA_PROBS
9401027
widget._commit_predictions()
941-
out = self.get_output(widget.Outputs.predictions)
1028+
out = self.get_output(widget.Outputs.selected_predictions)
9421029
self.assertEqual(list(out.metas[0]), [0, 10, 11, 0, 1, 110, 111, 112])
9431030

9441031
widget.shown_probs = widget.MODEL_PROBS
9451032
widget._commit_predictions()
946-
out = self.get_output(widget.Outputs.predictions)
1033+
out = self.get_output(widget.Outputs.selected_predictions)
9471034
self.assertEqual(list(out.metas[0]), [0, 10, 11, 1, 110, 111, 112])
9481035

9491036
widget.shown_probs = widget.BOTH_PROBS
9501037
widget._commit_predictions()
951-
out = self.get_output(widget.Outputs.predictions)
1038+
out = self.get_output(widget.Outputs.selected_predictions)
9521039
self.assertEqual(list(out.metas[0]), [0, 10, 11, 1, 110, 111, 112])
9531040

9541041
widget.shown_probs = widget.BOTH_PROBS + 1
9551042
widget._commit_predictions()
956-
out = self.get_output(widget.Outputs.predictions)
1043+
out = self.get_output(widget.Outputs.selected_predictions)
9571044
self.assertEqual(list(out.metas[0]), [0, 10, 1, 110])
9581045

9591046
widget.shown_probs = widget.BOTH_PROBS + 2
9601047
widget._commit_predictions()
961-
out = self.get_output(widget.Outputs.predictions)
1048+
out = self.get_output(widget.Outputs.selected_predictions)
9621049
self.assertEqual(list(out.metas[0]), [0, 11, 1, 111])
9631050

9641051
widget.shown_probs = widget.BOTH_PROBS + 3
9651052
widget._commit_predictions()
966-
out = self.get_output(widget.Outputs.predictions)
1053+
out = self.get_output(widget.Outputs.selected_predictions)
9671054
self.assertEqual(list(out.metas[0]), [0, 0, 1, 112])
9681055

9691056
def test_output_regression(self):
@@ -973,7 +1060,7 @@ def test_output_regression(self):
9731060
LinearRegressionLearner()(self.housing), 0)
9741061
self.send_signal(widget.Inputs.predictors,
9751062
MeanLearner()(self.housing), 1)
976-
out = self.get_output(widget.Outputs.predictions)
1063+
out = self.get_output(widget.Outputs.selected_predictions)
9771064
np.testing.assert_equal(
9781065
out.metas[:, [0, 2]],
9791066
np.hstack([pred.results.predicted.T for pred in widget.predictors]))
@@ -1001,12 +1088,12 @@ def test_classless(self):
10011088

10021089
widget.shown_probs = widget.NO_PROBS
10031090
widget._commit_predictions()
1004-
out = self.get_output(widget.Outputs.predictions)
1091+
out = self.get_output(widget.Outputs.selected_predictions)
10051092
self.assertEqual(list(out.metas[0]), [0, 1, 2])
10061093

10071094
widget.shown_probs = widget.MODEL_PROBS
10081095
widget._commit_predictions()
1009-
out = self.get_output(widget.Outputs.predictions)
1096+
out = self.get_output(widget.Outputs.selected_predictions)
10101097
self.assertEqual(list(out.metas[0]), [0, 10, 11, 1, 110, 111, 2, 210, 211, 212])
10111098

10121099
@patch("Orange.widgets.evaluate.owpredictions.usable_scorers",
@@ -1047,7 +1134,7 @@ def test_multi_target_input(self):
10471134

10481135
self.send_signal(widget.Inputs.data, data)
10491136
self.send_signal(widget.Inputs.predictors, mock_model, 1)
1050-
pred = self.get_output(widget.Outputs.predictions)
1137+
pred = self.get_output(widget.Outputs.selected_predictions)
10511138
self.assertIsInstance(pred, Table)
10521139

10531140
def test_error_controls_visibility(self):
@@ -1202,7 +1289,7 @@ def test_output_error_reg(self):
12021289
self.send_signal(self.widget.Inputs.predictors, lin_reg(data), 0)
12031290
self.send_signal(self.widget.Inputs.predictors,
12041291
LinearRegressionLearner(fit_intercept=False)(data), 1)
1205-
pred = self.get_output(self.widget.Outputs.predictions)
1292+
pred = self.get_output(self.widget.Outputs.selected_predictions)
12061293

12071294
names = ["", " (error)"]
12081295
names = [f"{n}{i}" for i in ("", " (1)") for n in names]
@@ -1220,7 +1307,7 @@ def test_output_error_cls(self):
12201307
with data.unlocked(data.Y):
12211308
data.Y[1] = np.nan
12221309
self.send_signal(self.widget.Inputs.data, data)
1223-
pred = self.get_output(self.widget.Outputs.predictions)
1310+
pred = self.get_output(self.widget.Outputs.selected_predictions)
12241311

12251312
names = [""] + [f" ({v})" for v in
12261313
list(data.domain.class_var.values) + ["error"]]

i18n/si/msgs.jaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8940,6 +8940,7 @@ widgets/evaluate/owpredictions.py:
89408940
Data: Podatki
89418941
Predictors: Modeli
89428942
class `Outputs`:
8943+
Selected Predictions: Izbrane napovedi
89438944
Predictions: Napovedi
89448945
Evaluation Results: Rezultati vrednotenja
89458946
class `Warning`:

0 commit comments

Comments
 (0)