Skip to content

Commit bf6edbe

Browse files
authored
Merge pull request #3951 from janezd/test-score-sorting
[FIX] Test and Score: Sort numerically, not alphabetically
2 parents 40916e2 + 0887a2a commit bf6edbe

File tree

4 files changed

+92
-10
lines changed

4 files changed

+92
-10
lines changed

Orange/widgets/evaluate/owtestlearners.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def update_stats_model(self):
533533
for stat, scorer in zip(stats, self.scorers):
534534
item = QStandardItem()
535535
if stat.success:
536-
item.setText("{:.3f}".format(stat.value[0]))
536+
item.setData(float(stat.value[0]), Qt.DisplayRole)
537537
else:
538538
item.setToolTip(str(stat.exception))
539539
if scorer.name in self.score_table.shown_scores:

Orange/widgets/evaluate/tests/test_owtestlearners.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,10 @@ def __call__(self, data):
277277
# Ensure that the click on header caused an ascending sort
278278
# Ascending sort means that wrong model should be listed first
279279
self.assertEqual(header.sortIndicatorOrder(), Qt.AscendingOrder)
280-
self.assertEqual(view.model().item(0, 0).text(), "VersicolorLearner")
280+
self.assertEqual(view.model().index(0, 0).data(), "VersicolorLearner")
281281

282282
self.send_signal(self.widget.Inputs.test_data, versicolor, wait=5000)
283-
self.assertEqual(view.model().item(0, 0).text(), "SetosaLearner")
283+
self.assertEqual(view.model().index(0, 0).data(), "SetosaLearner")
284284

285285
self.widget.hide()
286286

@@ -365,10 +365,11 @@ def test_scores_log_reg_advanced(self):
365365
[1, 1, 1.23, 23.8, 5.], [1., 2., 3., 4., 3.], "yynnn"))
366366
)
367367

368-
self.assertTupleEqual(self._test_scores(
369-
table_train, table_test, LogisticRegressionLearner(),
370-
OWTestLearners.TestOnTest, None),
371-
(0.667, 0.8, 0.8, 0.867, 0.8))
368+
np.testing.assert_almost_equal(
369+
self._test_scores(table_train, table_test,
370+
LogisticRegressionLearner(),
371+
OWTestLearners.TestOnTest, None),
372+
(2 / 3, 0.8, 0.8, 13 / 15, 0.8))
372373

373374
def test_scores_cross_validation(self):
374375
"""

Orange/widgets/evaluate/tests/test_utils.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
import unittest
44
import collections
55

6+
import numpy as np
7+
68
from AnyQt.QtWidgets import QMenu
7-
from AnyQt.QtCore import QPoint
9+
from AnyQt.QtGui import QStandardItem
10+
from AnyQt.QtCore import QPoint, Qt
811

912
from Orange.widgets.evaluate.utils import ScoreTable
1013
from Orange.widgets.tests.base import GuiTest
@@ -70,5 +73,48 @@ def test_update_shown_columns(self):
7073
not header.isSectionHidden(i),
7174
msg="error in section {}({})".format(i, name))
7275

76+
def test_sorting(self):
77+
def order(n=5):
78+
return "".join(model.index(i, 0).data() for i in range(n))
79+
80+
score_table = ScoreTable(None)
81+
82+
data = [
83+
["D", 11.0, 15.3],
84+
["C", 5.0, -15.4],
85+
["b", 20.0, np.nan],
86+
["A", None, None],
87+
["E", "", 0.0]
88+
]
89+
for data_row in data:
90+
row = []
91+
for x in data_row:
92+
item = QStandardItem()
93+
if x is not None:
94+
item.setData(x, Qt.DisplayRole)
95+
row.append(item)
96+
score_table.model.appendRow(row)
97+
98+
model = score_table.view.model()
99+
100+
model.sort(0, Qt.AscendingOrder)
101+
self.assertEqual(order(), "AbCDE")
102+
103+
model.sort(0, Qt.DescendingOrder)
104+
self.assertEqual(order(), "EDCbA")
105+
106+
model.sort(1, Qt.AscendingOrder)
107+
self.assertEqual(order(3), "CDb")
108+
109+
model.sort(1, Qt.DescendingOrder)
110+
self.assertEqual(order(3), "bDC")
111+
112+
model.sort(2, Qt.AscendingOrder)
113+
self.assertEqual(order(3), "CED")
114+
115+
model.sort(2, Qt.DescendingOrder)
116+
self.assertEqual(order(3), "DEC")
117+
118+
73119
if __name__ == "__main__":
74120
unittest.main()

Orange/widgets/evaluate/utils.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from AnyQt.QtWidgets import QHeaderView, QStyledItemDelegate, QMenu
88
from AnyQt.QtGui import QStandardItemModel, QStandardItem
9-
from AnyQt.QtCore import Qt, QSize, QObject, pyqtSignal as Signal
9+
from AnyQt.QtCore import Qt, QSize, QObject, pyqtSignal as Signal, \
10+
QSortFilterProxyModel
1011
from sklearn.exceptions import UndefinedMetricWarning
1112

1213
from Orange.data import Variable, DiscreteVariable, ContinuousVariable
@@ -98,6 +99,32 @@ def thunked():
9899
return thunked
99100

100101

102+
class ScoreModel(QSortFilterProxyModel):
103+
def lessThan(self, left, right):
104+
def is_bad(x):
105+
return not isinstance(x, (int, float, str)) \
106+
or isinstance(x, float) and np.isnan(x)
107+
108+
left = left.data()
109+
right = right.data()
110+
is_ascending = self.sortOrder() == Qt.AscendingOrder
111+
112+
# bad entries go below; if both are bad, left remains above
113+
if is_bad(left) or is_bad(right):
114+
return is_bad(right) == is_ascending
115+
116+
# for data of different types, numbers are at the top
117+
if type(left) is not type(right):
118+
return isinstance(left, float) == is_ascending
119+
120+
# case insensitive comparison for strings
121+
if isinstance(left, str):
122+
return left.upper() < right.upper()
123+
124+
# otherwise, compare numbers
125+
return left < right
126+
127+
101128
class ScoreTable(OWComponent, QObject):
102129
shown_scores = \
103130
Setting(set(chain(*BUILTIN_SCORERS_ORDER.values())))
@@ -109,6 +136,12 @@ def sizeHint(self, *args):
109136
size = super().sizeHint(*args)
110137
return QSize(size.width(), size.height() + 6)
111138

139+
def displayText(self, value, locale):
140+
if isinstance(value, float):
141+
return f"{value:.3f}"
142+
else:
143+
return super().displayText(value, locale)
144+
112145
def __init__(self, master):
113146
QObject.__init__(self)
114147
OWComponent.__init__(self, master)
@@ -125,7 +158,9 @@ def __init__(self, master):
125158

126159
self.model = QStandardItemModel(master)
127160
self.model.setHorizontalHeaderLabels(["Method"])
128-
self.view.setModel(self.model)
161+
self.sorted_model = ScoreModel()
162+
self.sorted_model.setSourceModel(self.model)
163+
self.view.setModel(self.sorted_model)
129164
self.view.setItemDelegate(self.ItemDelegate())
130165

131166
def _column_names(self):

0 commit comments

Comments
 (0)