diff --git a/Orange/widgets/data/owrank.py b/Orange/widgets/data/owrank.py index c7fdad641f3..4e202a6e640 100644 --- a/Orange/widgets/data/owrank.py +++ b/Orange/widgets/data/owrank.py @@ -147,6 +147,13 @@ def resetSorting(self, yes_reset=False): if yes_reset: super().resetSorting() + def _argsortData(self, data, order): + """Always sort NaNs last""" + indices = np.argsort(data, kind='mergesort') + if order == Qt.DescendingOrder: + return np.roll(indices[::-1], -np.isnan(data).sum()) + return indices + class OWRank(OWWidget): name = "Rank" @@ -417,13 +424,14 @@ def updateScores(self): self.ranksModel.wrap(model_array.tolist()) self.ranksModel.setHorizontalHeaderLabels(('#',) + labels) - self.ranksView.setColumnWidth(0, 30) + self.ranksView.setColumnWidth(0, 40) # Re-apply sort try: sort_column, sort_order = self.sorting if sort_column < len(labels): self.ranksModel.sort(sort_column + 1, sort_order) # +1 for '#' (discrete count) column + self.ranksView.horizontalHeader().setSortIndicator(sort_column + 1, sort_order) except ValueError: pass @@ -541,7 +549,7 @@ def migrate_settings(cls, settings, version): # Saved selected_rows will likely be incorrect if version is None or version < 2: column, order = 0, Qt.DescendingOrder - headerState = settings.pop("headerState") + headerState = settings.pop("headerState", None) # Lacking knowledge of last problemType, use discrete ranks view's ordering if isinstance(headerState, (tuple, list)): diff --git a/Orange/widgets/data/tests/test_owrank.py b/Orange/widgets/data/tests/test_owrank.py index 72fcdb1f9d2..5b905f36857 100644 --- a/Orange/widgets/data/tests/test_owrank.py +++ b/Orange/widgets/data/tests/test_owrank.py @@ -231,6 +231,24 @@ def test_scores_sorting(self): order2 = self.widget.ranksModel.mapToSourceRows(...).tolist() self.assertNotEqual(order1, order2) + def test_scores_nan_sorting(self): + """Check NaNs are sorted last""" + data = self.iris.copy() + data.get_column_view('petal length')[0][:] = np.nan + self.send_signal(self.widget.Inputs.data, data) + + # Assert last row is all nan + for order in (Qt.AscendingOrder, + Qt.DescendingOrder): + self.widget.ranksView.horizontalHeader().setSortIndicator(1, order) + last_row = self.widget.ranksModel[self.widget.ranksModel.mapToSourceRows(...)[-1]] + np.testing.assert_array_equal(last_row, np.repeat(np.nan, 3)) + + def test_default_sort_indicator(self): + self.send_signal(self.widget.Inputs.data, self.iris) + self.assertNotEqual( + 0, self.widget.ranksView.horizontalHeader().sortIndicatorSection()) + def test_data_which_make_scorer_nan(self): """ Tests if widget crashes due to too high (Infinite) calculated values. diff --git a/Orange/widgets/tests/test_itemmodels.py b/Orange/widgets/tests/test_itemmodels.py index 9f42a4b4da7..6ca635e33e6 100644 --- a/Orange/widgets/tests/test_itemmodels.py +++ b/Orange/widgets/tests/test_itemmodels.py @@ -7,7 +7,7 @@ from Orange.data import Domain, ContinuousVariable from Orange.widgets.utils.itemmodels import \ - PyTableModel, PyListModel, DomainModel, _argsort + AbstractSortTableModel, PyTableModel, PyListModel, DomainModel, _argsort class TestArgsort(TestCase): @@ -120,6 +120,20 @@ def test_other_roles(self): Qt.TextAlignmentRole)) +class TestAbstractSortTableModel(TestCase): + def setUp(self): + assert issubclass(PyTableModel, AbstractSortTableModel) + self.model = PyTableModel([[1, 4], + [2, 3]]) + + def test_sorting(self): + self.model.sort(1, Qt.AscendingOrder) + self.assertSequenceEqual(self.model.mapToSourceRows(...).tolist(), [1, 0]) + + self.model.sort(1, Qt.DescendingOrder) + self.assertSequenceEqual(self.model.mapToSourceRows(...).tolist(), [0, 1]) + + class TestPyListModel(TestCase): @classmethod def setUpClass(cls): diff --git a/Orange/widgets/utils/itemmodels.py b/Orange/widgets/utils/itemmodels.py index 45d830319b3..d69d755dbae 100644 --- a/Orange/widgets/utils/itemmodels.py +++ b/Orange/widgets/utils/itemmodels.py @@ -157,6 +157,16 @@ def resetSorting(self): """Invalidates the current sorting""" return self.sort(-1) + def _argsortData(self, data: numpy.ndarray, order): + """ + Return indices of sorted data. May be reimplemented to handle + sorting in a certain way, e.g. to sort NaN values last. + """ + indices = numpy.argsort(data, kind="mergesort") + if order == Qt.DescendingOrder: + indices = indices[::-1] + return indices + def sort(self, column: int, order: Qt.SortOrder=Qt.AscendingOrder): """ Sort the data by `column` into `order`. @@ -184,17 +194,12 @@ def sort(self, column: int, order: Qt.SortOrder=Qt.AscendingOrder): indices = None if column >= 0: data = numpy.asarray(self._sortColumnData(column)) - if data is not None: - if data.dtype == object: - data = data.astype(str) - indices = numpy.argsort(data, kind="mergesort") - else: - indices = numpy.arange(self.rowCount()) - - if order == Qt.DescendingOrder: - indices = indices[::-1] + if data is None: + data = numpy.arange(self.rowCount()) + elif data.dtype == object: + data = data.astype(str) - indices = self.mapToSourceRows(indices) + indices = self.mapToSourceRows(self._argsortData(data, order)) if indices is not None: self.__sortInd = indices