Skip to content

Commit 9adddfa

Browse files
committed
Rank widget concurrent mixin
1 parent 62e5c69 commit 9adddfa

File tree

2 files changed

+104
-36
lines changed

2 files changed

+104
-36
lines changed

Orange/widgets/data/owrank.py

Lines changed: 73 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,39 @@
1-
import warnings
2-
from collections import namedtuple, OrderedDict
31
import logging
2+
import warnings
3+
from collections import OrderedDict, namedtuple
44
from functools import partial
55
from itertools import chain
6+
from types import SimpleNamespace
7+
from typing import Callable, List, Any
68

79
import numpy as np
8-
from scipy.sparse import issparse
9-
10+
from AnyQt.QtCore import (
11+
QItemSelection, QItemSelectionModel, QItemSelectionRange, Qt
12+
)
1013
from AnyQt.QtGui import QFontMetrics
1114
from AnyQt.QtWidgets import (
12-
QTableView, QRadioButton, QButtonGroup, QGridLayout,
13-
QStackedWidget, QHeaderView, QCheckBox, QItemDelegate,
14-
)
15-
from AnyQt.QtCore import (
16-
Qt, QItemSelection, QItemSelectionRange, QItemSelectionModel,
15+
QButtonGroup, QCheckBox, QGridLayout, QHeaderView, QItemDelegate,
16+
QRadioButton, QStackedWidget, QTableView
1717
)
18-
1918
from orangewidget.settings import IncompatibleContext
20-
from Orange.data import (Table, Domain, ContinuousVariable, DiscreteVariable,
21-
StringVariable)
19+
from scipy.sparse import issparse
20+
21+
from Orange.data import (
22+
ContinuousVariable, DiscreteVariable, Domain, StringVariable, Table
23+
)
2224
from Orange.data.util import get_unique_names_duplicates
2325
from Orange.misc.cache import memoize_method
2426
from Orange.preprocess import score
25-
from Orange.widgets import report
26-
from Orange.widgets import gui
27-
from Orange.widgets.settings import (DomainContextHandler, Setting,
28-
ContextSetting)
27+
from Orange.widgets import gui, report
28+
from Orange.widgets.settings import (
29+
ContextSetting, DomainContextHandler, Setting
30+
)
31+
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState
2932
from Orange.widgets.utils.itemmodels import PyTableModel
3033
from Orange.widgets.utils.sql import check_sql_input
31-
from Orange.widgets.utils.widgetpreview import WidgetPreview
3234
from Orange.widgets.utils.state_summary import format_summary_details
33-
from Orange.widgets.widget import (
34-
OWWidget, Msg, Input, Output, AttributeList
35-
)
36-
35+
from Orange.widgets.utils.widgetpreview import WidgetPreview
36+
from Orange.widgets.widget import AttributeList, Input, Msg, Output, OWWidget
3737

3838
log = logging.getLogger(__name__)
3939

@@ -167,7 +167,12 @@ def _argsortData(self, data, order):
167167
return indices
168168

169169

170-
class OWRank(OWWidget):
170+
class Results(SimpleNamespace):
171+
method_scores = None
172+
scorer_scores = None
173+
174+
175+
class OWRank(OWWidget, ConcurrentWidgetMixin):
171176
name = "Rank"
172177
description = "Rank and filter data features by their relevance."
173178
icon = "icons/Rank.svg"
@@ -213,9 +218,9 @@ class Warning(OWWidget.Warning):
213218
renamed_variables = Msg(
214219
"Variables with duplicated names have been renamed.")
215220

216-
217221
def __init__(self):
218-
super().__init__()
222+
OWWidget.__init__(self)
223+
ConcurrentWidgetMixin.__init__(self)
219224
self.scorers = OrderedDict()
220225
self.out_domain_desc = None
221226
self.data = None
@@ -226,7 +231,6 @@ def __init__(self):
226231
if method.is_default}
227232

228233
# GUI
229-
230234
self.ranksModel = model = TableModel(parent=self) # type: TableModel
231235
self.ranksView = view = TableView(self) # type: TableView
232236
self.mainArea.layout().addWidget(view)
@@ -360,7 +364,7 @@ def set_data(self, data):
360364

361365
def handleNewSignals(self):
362366
self.setStatusMessage('Running')
363-
self.updateScores()
367+
self.update_scores()
364368
self.setStatusMessage('')
365369
self.on_select()
366370

@@ -413,7 +417,29 @@ def get_scorer_scores(self, scorer):
413417
for i in range(1, 1 + scores.shape[1])))
414418
return scores, labels
415419

416-
def updateScores(self):
420+
def run(
421+
self,
422+
methods: List[ScoreMeta],
423+
scorers: List[ScoreMeta],
424+
state: TaskState
425+
) -> Results:
426+
progress_steps = iter(np.linspace(0, 100, len(methods) + len(scorers)))
427+
428+
def call_with_cb(get_scores: Callable, method: ScoreMeta):
429+
scores = get_scores(method)
430+
state.set_progress_value(next(progress_steps))
431+
return scores
432+
433+
method_scores = tuple(
434+
(call_with_cb(self.get_method_scores, method), method.shortname)
435+
for method in methods
436+
)
437+
scorer_scores = tuple(
438+
call_with_cb(self.get_scorer_scores, scorer) for scorer in scorers
439+
)
440+
return Results(method_scores=method_scores, scorer_scores=scorer_scores)
441+
442+
def update_scores(self):
417443
if self.data is None:
418444
self.ranksModel.clear()
419445
self.Outputs.scores.send(None)
@@ -434,16 +460,21 @@ def updateScores(self):
434460
else:
435461
self.Error.inadequate_learner(scorer.name, scorer.learner_adequacy_err_msg)
436462

437-
method_scores = tuple(self.get_method_scores(method)
438-
for method in methods)
463+
self.start(
464+
self.run,
465+
methods,
466+
scorers,
467+
)
439468

440-
scorer_scores, scorer_labels = (), ()
441-
if scorers:
442-
scorer_scores, scorer_labels = zip(*(self.get_scorer_scores(scorer)
443-
for scorer in scorers))
444-
scorer_labels = tuple(chain.from_iterable(scorer_labels))
469+
def on_done(self, result: Results) -> None:
470+
method_scores, method_labels = (
471+
zip(*result.method_scores) if result.method_scores else ((), ())
472+
)
473+
scorer_scores, scorer_labels = (
474+
zip(*result.scorer_scores) if result.scorer_scores else ((), ())
475+
)
445476

446-
labels = tuple(method.shortname for method in methods) + scorer_labels
477+
labels = method_labels + tuple(chain.from_iterable(scorer_labels))
447478
model_array = np.column_stack(
448479
([len(a.values) if a.is_discrete else np.nan
449480
for a in self.data.domain.attributes],) +
@@ -470,6 +501,12 @@ def updateScores(self):
470501
self.autoSelection()
471502
self.Outputs.scores.send(self.create_scores_table(labels))
472503

504+
def on_exception(self, ex: Exception):
505+
raise ex
506+
507+
def on_partial_result(self, result: Any) -> None:
508+
pass
509+
473510
def on_select(self):
474511
# Save indices of attributes in the original, unsorted domain
475512
selected_rows = self.ranksView.selectionModel().selectedRows(0)
@@ -529,7 +566,7 @@ def methodSelectionChanged(self, state, method_name):
529566
elif method_name in self.selected_methods:
530567
self.selected_methods.remove(method_name)
531568

532-
self.updateScores()
569+
self.update_scores()
533570

534571
def send_report(self):
535572
if not self.data:

Orange/widgets/data/tests/test_owrank.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def test_input_scorer(self):
5151
"""Check widget's scorer with scorer on the input"""
5252
self.assertEqual(self.widget.scorers, {})
5353
self.send_signal(self.widget.Inputs.scorer, self.log_reg, 1)
54+
self.wait_until_finished()
5455
value = self.widget.scorers[1]
5556
self.assertEqual(self.log_reg, value.scorer)
5657
self.assertIsInstance(value.scorer, Scorer)
@@ -72,6 +73,7 @@ def test_input_scorer_fitter(self):
7273
heart_disease):
7374
with self.subTest(data=data.name):
7475
self.send_signal('Data', data)
76+
self.wait_until_finished()
7577
scores = [model.data(model.index(row, model.columnCount() - 1))
7678
for row in range(model.rowCount())]
7779
self.assertEqual(len(scores), len(data.domain.attributes))
@@ -94,6 +96,7 @@ def test_input_scorer_disconnect(self):
9496
def test_output_data(self):
9597
"""Check data on the output after apply"""
9698
self.send_signal(self.widget.Inputs.data, self.iris)
99+
self.wait_until_finished()
97100
output = self.get_output(self.widget.Outputs.reduced_data)
98101
self.assertIsInstance(output, Table)
99102
self.assertEqual(len(output.X), len(self.iris))
@@ -104,6 +107,7 @@ def test_output_data(self):
104107
def test_output_scores(self):
105108
"""Check scores on the output after apply"""
106109
self.send_signal(self.widget.Inputs.data, self.iris)
110+
self.wait_until_finished()
107111
output = self.get_output(self.widget.Outputs.scores)
108112
self.assertIsInstance(output, Table)
109113
self.assertEqual(output.X.shape, (len(self.iris.domain.attributes), 2))
@@ -114,12 +118,14 @@ def test_output_scores_with_scorer(self):
114118
"""Check scores on the output after apply with scorer on the input"""
115119
self.send_signal(self.widget.Inputs.data, self.iris)
116120
self.send_signal(self.widget.Inputs.scorer, self.log_reg, 1)
121+
self.wait_until_finished()
117122
output = self.get_output(self.widget.Outputs.scores)
118123
self.assertIsInstance(output, Table)
119124
self.assertEqual(output.X.shape, (len(self.iris.domain.attributes), 5))
120125

121126
def test_output_features(self):
122127
self.send_signal(self.widget.Inputs.data, self.iris)
128+
self.wait_until_finished()
123129
output = self.get_output(self.widget.Outputs.features)
124130
self.assertIsInstance(output, AttributeList)
125131
self.send_signal(self.widget.Inputs.data, None)
@@ -161,8 +167,10 @@ def test_cls_scorer_reg_data(self):
161167
"""Check scores on the output with inadequate scorer"""
162168
self.send_signal(self.widget.Inputs.data, self.housing)
163169
self.send_signal(self.widget.Inputs.scorer, self.pca, 1)
170+
self.wait_until_finished()
164171
with patch("Orange.widgets.data.owrank.log.error") as log:
165172
self.send_signal(self.widget.Inputs.scorer, self.log_reg, 2)
173+
self.wait_until_finished()
166174
log.assert_called()
167175
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
168176
(len(self.housing.domain.attributes), 16))
@@ -171,8 +179,10 @@ def test_reg_scorer_cls_data(self):
171179
"""Check scores on the output with inadequate scorer"""
172180
self.send_signal(self.widget.Inputs.data, self.iris)
173181
self.send_signal(self.widget.Inputs.scorer, self.pca, 1)
182+
self.wait_until_finished()
174183
with patch("Orange.widgets.data.owrank.log.error") as log:
175184
self.send_signal(self.widget.Inputs.scorer, self.lin_reg, 2)
185+
self.wait_until_finished()
176186
log.assert_called()
177187
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
178188
(len(self.iris.domain.attributes), 7))
@@ -181,6 +191,7 @@ def test_scores_updates_cls(self):
181191
"""Check arbitrary workflow with classification data"""
182192
self.send_signal(self.widget.Inputs.data, self.iris)
183193
self.send_signal(self.widget.Inputs.scorer, self.log_reg, 1)
194+
self.wait_until_finished()
184195
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
185196
(len(self.iris.domain.attributes), 5))
186197
self._get_checkbox('Gini').setChecked(False)
@@ -190,16 +201,20 @@ def test_scores_updates_cls(self):
190201
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
191202
(len(self.iris.domain.attributes), 5))
192203
self.send_signal(self.widget.Inputs.scorer, self.log_reg, 2)
204+
self.wait_until_finished()
193205
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
194206
(len(self.iris.domain.attributes), 8))
195207
self.send_signal(self.widget.Inputs.scorer, None, 1)
208+
self.wait_until_finished()
196209
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
197210
(len(self.iris.domain.attributes), 5))
198211
self.send_signal(self.widget.Inputs.scorer, self.log_reg, 1)
212+
self.wait_until_finished()
199213
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
200214
(len(self.iris.domain.attributes), 8))
201215
with patch("Orange.widgets.data.owrank.log.error") as log:
202216
self.send_signal(self.widget.Inputs.scorer, self.lin_reg, 3)
217+
self.wait_until_finished()
203218
log.assert_called()
204219
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
205220
(len(self.iris.domain.attributes), 9))
@@ -208,6 +223,7 @@ def test_scores_updates_reg(self):
208223
"""Check arbitrary workflow with regression data"""
209224
self.send_signal(self.widget.Inputs.data, self.housing)
210225
self.send_signal(self.widget.Inputs.scorer, self.lin_reg, 1)
226+
self.wait_until_finished()
211227
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
212228
(len(self.housing.domain.attributes), 3))
213229

@@ -220,10 +236,12 @@ def test_scores_updates_reg(self):
220236
(len(self.housing.domain.attributes), 3))
221237

222238
self.send_signal(self.widget.Inputs.scorer, None, 1)
239+
self.wait_until_finished()
223240
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
224241
(len(self.housing.domain.attributes), 2))
225242

226243
self.send_signal(self.widget.Inputs.scorer, self.lin_reg, 1)
244+
self.wait_until_finished()
227245
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
228246
(len(self.housing.domain.attributes), 3))
229247

@@ -232,20 +250,24 @@ def test_scores_updates_no_class(self):
232250
data = Table.from_table(Domain(self.iris.domain.variables), self.iris)
233251
self.assertIsNone(data.domain.class_var)
234252
self.send_signal(self.widget.Inputs.data, data)
253+
self.wait_until_finished()
235254
self.assertIsNone(self.get_output(self.widget.Outputs.scores))
236255

237256
with patch("Orange.widgets.data.owrank.log.error") as log:
238257
self.send_signal(self.widget.Inputs.scorer, self.lin_reg, 1)
258+
self.wait_until_finished()
239259
log.assert_called()
240260
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
241261
(len(self.iris.domain.variables), 1))
242262

243263
self.send_signal(self.widget.Inputs.scorer, self.pca, 1)
264+
self.wait_until_finished()
244265
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
245266
(len(self.iris.domain.variables), 7))
246267

247268
with patch("Orange.widgets.data.owrank.log.error") as log:
248269
self.send_signal(self.widget.Inputs.scorer, self.lin_reg, 2)
270+
self.wait_until_finished()
249271
log.assert_called()
250272
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
251273
(len(self.iris.domain.variables), 8))
@@ -263,6 +285,7 @@ def test_no_class_data_learner_class_reg(self):
263285

264286
with patch("Orange.widgets.data.owrank.log.error") as log:
265287
self.send_signal(self.widget.Inputs.scorer, random_forest, 1)
288+
self.wait_until_finished()
266289
log.assert_called()
267290

268291
self.assertEqual(self.get_output(self.widget.Outputs.scores).X.shape,
@@ -271,8 +294,10 @@ def test_no_class_data_learner_class_reg(self):
271294
def test_scores_sorting(self):
272295
"""Check clicking on header column orders scores in a different way"""
273296
self.send_signal(self.widget.Inputs.data, self.iris)
297+
self.wait_until_finished()
274298
order1 = self.widget.ranksModel.mapToSourceRows(...).tolist()
275299
self._get_checkbox('FCBF').setChecked(True)
300+
self.wait_until_finished()
276301
self.widget.ranksView.horizontalHeader().setSortIndicator(3, Qt.DescendingOrder)
277302
order2 = self.widget.ranksModel.mapToSourceRows(...).tolist()
278303
self.assertNotEqual(order1, order2)
@@ -282,6 +307,7 @@ def test_scores_nan_sorting(self):
282307
data = self.iris.copy()
283308
data.get_column_view('petal length')[0][:] = np.nan
284309
self.send_signal(self.widget.Inputs.data, data)
310+
self.wait_until_finished()
285311

286312
# Assert last row is all nan
287313
for order in (Qt.AscendingOrder,
@@ -292,6 +318,7 @@ def test_scores_nan_sorting(self):
292318

293319
def test_default_sort_indicator(self):
294320
self.send_signal(self.widget.Inputs.data, self.iris)
321+
self.wait_until_finished()
295322
self.assertNotEqual(
296323
0, self.widget.ranksView.horizontalHeader().sortIndicatorSection())
297324

@@ -354,6 +381,7 @@ def test_auto_selection_manual(self):
354381
data = Table("heart_disease")
355382
dom = data.domain
356383
self.send_signal(w.Inputs.data, data)
384+
self.wait_until_finished()
357385

358386
# Sort by number of values and set selection to attributes with most
359387
# values. This must select the top 4 rows.
@@ -424,6 +452,7 @@ def test_dataset(self):
424452
def test_selected_rows(self):
425453
w = self.widget
426454
self.send_signal(w.Inputs.data, self.iris)
455+
self.wait_until_finished()
427456

428457
# select first and second row
429458
w.selected_rows = [1, 2]
@@ -438,6 +467,7 @@ def test_summary(self):
438467
output_sum = self.widget.info.set_output_summary = Mock()
439468

440469
self.send_signal(self.widget.Inputs.data, data)
470+
self.wait_until_finished()
441471
input_sum.assert_called_with(len(data), format_summary_details(data))
442472
output = self.get_output(self.widget.Outputs.reduced_data)
443473
output_sum.assert_called_with(len(output),
@@ -446,6 +476,7 @@ def test_summary(self):
446476
input_sum.reset_mock()
447477
output_sum.reset_mock()
448478
self.send_signal(self.widget.Inputs.data, None)
479+
self.wait_until_finished()
449480
input_sum.assert_called_once()
450481
self.assertEqual(input_sum.call_args[0][0].brief, "")
451482
output_sum.assert_called_once()

0 commit comments

Comments
 (0)