Skip to content

Commit 194f4a8

Browse files
committed
Rank widget concurrent mixin
1 parent 0c9d951 commit 194f4a8

File tree

2 files changed

+224
-97
lines changed

2 files changed

+224
-97
lines changed

Orange/widgets/data/owrank.py

Lines changed: 170 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,39 @@
1-
import warnings
2-
from collections import namedtuple, OrderedDict
31
import logging
2+
import warnings
3+
from collections import OrderedDict, namedtuple
44
from functools import partial
55
from itertools import chain
6+
from types import SimpleNamespace
7+
from typing import Any, Callable, List, Tuple
68

79
import numpy as np
8-
from scipy.sparse import issparse
9-
10+
from AnyQt.QtCore import (
11+
QItemSelection, QItemSelectionModel, QItemSelectionRange, Qt
12+
)
1013
from AnyQt.QtGui import QFontMetrics
1114
from AnyQt.QtWidgets import (
12-
QTableView, QRadioButton, QButtonGroup, QGridLayout,
13-
QStackedWidget, QHeaderView, QCheckBox, QItemDelegate,
15+
QButtonGroup, QCheckBox, QGridLayout, QHeaderView, QItemDelegate,
16+
QRadioButton, QStackedWidget, QTableView
1417
)
15-
from AnyQt.QtCore import (
16-
Qt, QItemSelection, QItemSelectionRange, QItemSelectionModel,
17-
)
18-
1918
from orangewidget.settings import IncompatibleContext
20-
from Orange.data import (Table, Domain, ContinuousVariable, DiscreteVariable,
21-
StringVariable)
19+
from scipy.sparse import issparse
20+
21+
from Orange.data import (
22+
ContinuousVariable, DiscreteVariable, Domain, StringVariable, Table
23+
)
2224
from Orange.data.util import get_unique_names_duplicates
23-
from Orange.misc.cache import memoize_method
2425
from Orange.preprocess import score
25-
from Orange.widgets import report
26-
from Orange.widgets import gui
27-
from Orange.widgets.settings import (DomainContextHandler, Setting,
28-
ContextSetting)
26+
from Orange.widgets import gui, report
27+
from Orange.widgets.settings import (
28+
ContextSetting, DomainContextHandler, Setting
29+
)
30+
from Orange.widgets.unsupervised.owdistances import InterruptException
31+
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState
2932
from Orange.widgets.utils.itemmodels import PyTableModel
3033
from Orange.widgets.utils.sql import check_sql_input
31-
from Orange.widgets.utils.widgetpreview import WidgetPreview
3234
from Orange.widgets.utils.state_summary import format_summary_details
33-
from Orange.widgets.widget import (
34-
OWWidget, Msg, Input, Output, AttributeList
35-
)
36-
35+
from Orange.widgets.utils.widgetpreview import WidgetPreview
36+
from Orange.widgets.widget import AttributeList, Input, Msg, Output, OWWidget
3737

3838
log = logging.getLogger(__name__)
3939

@@ -167,7 +167,79 @@ def _argsortData(self, data, order):
167167
return indices
168168

169169

170-
class OWRank(OWWidget):
170+
class Results(SimpleNamespace):
    """Container for the scores computed by a single ``run`` call."""

    # One (method, scores) pair for every built-in method that was evaluated.
    method_scores: Tuple[Tuple[ScoreMeta, np.ndarray], ...] = None
    # One (scorer, (scores, labels)) pair for every external scorer that was
    # evaluated; ``labels`` is the tuple built by ``get_scorer_scores``.
    scorer_scores: Tuple[
        Tuple[ScoreMeta, Tuple[np.ndarray, Tuple[str, ...]]], ...
    ] = None
173+
174+
175+
def get_method_scores(data: Table, method: ScoreMeta) -> np.ndarray:
    """Score every attribute of ``data`` with the built-in ``method``.

    The estimator is first applied to the whole table. If that raises
    ``ValueError``, it is applied to each attribute separately (a warning is
    logged). If that raises ``ValueError`` too, the failure is logged and an
    all-NaN array with one entry per attribute is returned.
    """
    score_fn = method.scorer()
    # The widget handles infs and nans.
    # Any errors in scorers need to be detected elsewhere.
    with np.errstate(all="ignore"):
        try:
            return np.asarray(score_fn(data))
        except ValueError:
            # Whole-table scoring failed; fall back to per-attribute scoring.
            pass

        try:
            per_attribute = np.array(
                [score_fn(data, attr) for attr in data.domain.attributes]
            )
        except ValueError:
            log.error("%s doesn't work on this data", method.name)
            return np.full(len(data.domain.attributes), np.nan)

        log.warning(
            "%s had to be computed separately for each variable", method.name
        )
        return per_attribute
196+
197+
198+
def get_scorer_scores(
    data: Table, scorer: ScoreMeta
) -> Tuple[np.ndarray, Tuple[str, ...]]:
    """Score ``data`` with an externally supplied ``scorer``.

    Returns a ``(scores, labels)`` pair. ``scores`` is the transposed result
    of ``scorer.scorer.score_data(data)`` and ``labels`` holds one label per
    column of ``scores``: the scorer's short name when there is a single
    column, otherwise the short name suffixed with ``_1``, ``_2``, ...

    If scoring raises ``ValueError`` or ``TypeError`` the failure is logged
    and a single all-NaN column with one row per attribute is returned.
    """
    try:
        scores = scorer.scorer.score_data(data).T
    except (ValueError, TypeError):
        log.error("%s doesn't work on this data", scorer.name)
        scores = np.full((len(data.domain.attributes), 1), np.nan)

    n_columns = scores.shape[1]
    if n_columns == 1:
        labels = (scorer.shortname,)
    else:
        # Several score columns: number the labels starting at 1.
        labels = tuple(
            scorer.shortname + "_" + str(i) for i in range(1, 1 + n_columns)
        )
    return scores, labels
216+
217+
218+
def run(
    data: Table,
    methods: List[ScoreMeta],
    scorers: List[ScoreMeta],
    state: TaskState,
) -> Results:
    """Compute scores for ``methods`` and ``scorers`` and report progress.

    After each individual score computation the progress value on ``state``
    is advanced and ``state.is_interruption_requested()`` is checked; when an
    interruption was requested, ``InterruptException`` is raised.

    Returns a ``Results`` whose ``method_scores`` and ``scorer_scores`` are
    tuples of ``(ScoreMeta, result)`` pairs, in the order given.
    """
    n_tasks = len(methods) + len(scorers)
    # Drop the leading 0 so that finishing the k-th of n tasks reports
    # k / n * 100. A plain ``linspace(0, 100, n)`` would report 0 after the
    # first task, and 0 instead of 100 when there is exactly one task.
    progress_steps = iter(np.linspace(0, 100, n_tasks + 1)[1:])

    def call_with_cb(get_scores: Callable, method: ScoreMeta):
        scores = get_scores(data, method)
        state.set_progress_value(next(progress_steps))
        if state.is_interruption_requested():
            raise InterruptException
        return scores

    method_scores = tuple(
        (method, call_with_cb(get_method_scores, method)) for method in methods
    )
    scorer_scores = tuple(
        (scorer, call_with_cb(get_scorer_scores, scorer)) for scorer in scorers
    )
    return Results(method_scores=method_scores, scorer_scores=scorer_scores)
240+
241+
242+
class OWRank(OWWidget, ConcurrentWidgetMixin):
171243
name = "Rank"
172244
description = "Rank and filter data features by their relevance."
173245
icon = "icons/Rank.svg"
@@ -211,20 +283,23 @@ class Warning(OWWidget.Warning):
211283
renamed_variables = Msg(
212284
"Variables with duplicated names have been renamed.")
213285

214-
215286
def __init__(self):
216-
super().__init__()
287+
OWWidget.__init__(self)
288+
ConcurrentWidgetMixin.__init__(self)
217289
self.scorers = OrderedDict()
218290
self.out_domain_desc = None
219291
self.data = None
220292
self.problem_type_mode = ProblemType.CLASSIFICATION
221293

294+
# results caches
295+
self.scorers_results = {}
296+
self.methods_results = {}
297+
222298
if not self.selected_methods:
223299
self.selected_methods = {method.name for method in SCORES
224300
if method.is_default}
225301

226302
# GUI
227-
228303
self.ranksModel = model = TableModel(parent=self) # type: TableModel
229304
self.ranksView = view = TableView(self) # type: TableView
230305
self.mainArea.layout().addWidget(view)
@@ -312,8 +387,9 @@ def set_data(self, data):
312387
self.ranksModel.clear()
313388
self.ranksModel.resetSorting(True)
314389

315-
self.get_method_scores.cache_clear() # pylint: disable=no-member
316-
self.get_scorer_scores.cache_clear() # pylint: disable=no-member
390+
self.scorers_results = {}
391+
self.methods_results = {}
392+
self.cancel()
317393

318394
self.Error.clear()
319395
self.Information.clear()
@@ -358,7 +434,7 @@ def set_data(self, data):
358434

359435
def handleNewSignals(self):
360436
self.setStatusMessage('Running')
361-
self.updateScores()
437+
self.update_scores()
362438
self.setStatusMessage('')
363439
self.on_select()
364440

@@ -370,86 +446,75 @@ def set_learner(self, scorer, id): # pylint: disable=redefined-builtin
370446
# Avoid caching a (possibly stale) previous instance of the same
371447
# Scorer passed via the same signal
372448
if id in self.scorers:
373-
# pylint: disable=no-member
374-
self.get_scorer_scores.cache_clear()
449+
self.scorers_results = {}
375450

376451
self.scorers[id] = ScoreMeta(scorer.name, scorer.name, scorer,
377452
ProblemType.from_variable(scorer.class_type),
378453
False)
379454

380-
@memoize_method()
381-
def get_method_scores(self, method):
382-
# These errors often happen, but they result in nans, which
383-
# are handled correctly by the widget
384-
estimator = method.scorer()
385-
data = self.data
386-
# The widget handles infs and nans.
387-
# Any errors in scorers need to be detected elsewhere.
388-
with np.errstate(all="ignore"):
389-
try:
390-
scores = np.asarray(estimator(data))
391-
except ValueError:
392-
try:
393-
scores = np.array([estimator(data, attr)
394-
for attr in data.domain.attributes])
395-
except ValueError:
396-
log.error("%s doesn't work on this data", method.name)
397-
scores = np.full(len(data.domain.attributes), np.nan)
398-
else:
399-
log.warning("%s had to be computed separately for each "
400-
"variable", method.name)
401-
return scores
402-
403-
@memoize_method()
404-
def get_scorer_scores(self, scorer):
405-
try:
406-
scores = scorer.scorer.score_data(self.data).T
407-
except (ValueError, TypeError):
408-
log.error("%s doesn't work on this data", scorer.name)
409-
scores = np.full((len(self.data.domain.attributes), 1), np.nan)
410-
411-
labels = ((scorer.shortname,)
412-
if scores.shape[1] == 1 else
413-
tuple(scorer.shortname + '_' + str(i)
414-
for i in range(1, 1 + scores.shape[1])))
415-
return scores, labels
416-
417-
def updateScores(self):
455+
def _get_methods(self):
    """Return the entries of ``SCORES`` that apply to the current state.

    An entry is kept when its name is in ``self.selected_methods``, its
    problem type equals ``self.problem_type_mode``, and either
    ``self.data.X`` is not sparse or its scorer declares
    ``supports_sparse_data``.
    """
    applicable = []
    for candidate in SCORES:
        if candidate.name not in self.selected_methods:
            continue
        if candidate.problem_type != self.problem_type_mode:
            continue
        # Sparse input is only given to scorers that declare support for it.
        if issparse(self.data.X) and not candidate.scorer.supports_sparse_data:
            continue
        applicable.append(candidate)
    return applicable
468+
469+
def _get_scorers(self):
    """Return the external scorers usable with the current problem type.

    A scorer from ``self.scorers`` is kept when its problem type is either
    ``self.problem_type_mode`` or ``ProblemType.UNSUPERVISED``. For every
    other scorer ``self.Error.inadequate_learner`` is raised as a side effect.
    """
    accepted_types = (self.problem_type_mode, ProblemType.UNSUPERVISED)
    usable = []
    for candidate in self.scorers.values():
        if candidate.problem_type not in accepted_types:
            # NOTE(review): set_learner builds ScoreMeta from five positional
            # values; confirm it exposes ``learner_adequacy_err_msg``.
            self.Error.inadequate_learner(
                candidate.name, candidate.learner_adequacy_err_msg
            )
            continue
        usable.append(candidate)
    return usable
482+
483+
def update_scores(self):
418484
if self.data is None:
419485
self.ranksModel.clear()
420486
self.Outputs.scores.send(None)
421487
return
422488

423-
methods = [method
424-
for method in SCORES
425-
if (method.name in self.selected_methods and
426-
method.problem_type == self.problem_type_mode and
427-
(not issparse(self.data.X) or
428-
method.scorer.supports_sparse_data))]
429-
430-
scorers = []
431489
self.Error.inadequate_learner.clear()
432-
for scorer in self.scorers.values():
433-
if scorer.problem_type in (self.problem_type_mode, ProblemType.UNSUPERVISED):
434-
scorers.append(scorer)
435-
else:
436-
self.Error.inadequate_learner(scorer.name, scorer.learner_adequacy_err_msg)
437490

438-
method_scores = tuple(self.get_method_scores(method)
439-
for method in methods)
491+
scorers = [
492+
s for s in self._get_scorers() if s not in self.scorers_results
493+
]
494+
methods = [
495+
m for m in self._get_methods() if m not in self.methods_results
496+
]
497+
self.start(run, self.data, methods, scorers)
440498

441-
scorer_scores, scorer_labels = (), ()
442-
if scorers:
443-
scorer_scores, scorer_labels = zip(*(self.get_scorer_scores(scorer)
444-
for scorer in scorers))
445-
scorer_labels = tuple(chain.from_iterable(scorer_labels))
499+
def on_done(self, result: Results) -> None:
500+
self.methods_results.update(result.method_scores)
501+
self.scorers_results.update(result.scorer_scores)
446502

447-
labels = tuple(method.shortname for method in methods) + scorer_labels
503+
methods = self._get_methods()
504+
method_labels = tuple(m.shortname for m in methods)
505+
method_scores = tuple(self.methods_results[m] for m in methods)
506+
507+
scores = [self.scorers_results[s] for s in self._get_scorers()]
508+
scorer_scores, scorer_labels = zip(*scores) if scores else ((), ())
509+
510+
labels = method_labels + tuple(chain.from_iterable(scorer_labels))
448511
model_array = np.column_stack(
449-
([len(a.values) if a.is_discrete else np.nan
450-
for a in self.data.domain.attributes],) +
451-
(method_scores if method_scores else ()) +
452-
(scorer_scores if scorer_scores else ())
512+
(
513+
[len(a.values) if a.is_discrete else np.nan
514+
for a in self.data.domain.attributes],
515+
)
516+
+ method_scores
517+
+ scorer_scores
453518
)
454519
for column, values in enumerate(model_array.T):
455520
self.ranksModel.setExtremesFrom(column, values)
@@ -464,13 +529,21 @@ def updateScores(self):
464529
if sort_column < len(labels):
465530
# adds 1 for '#' (discrete count) column
466531
self.ranksModel.sort(sort_column + 1, sort_order)
467-
self.ranksView.horizontalHeader().setSortIndicator(sort_column + 1, sort_order)
532+
self.ranksView.horizontalHeader().setSortIndicator(
533+
sort_column + 1, sort_order
534+
)
468535
except ValueError:
469536
pass
470537

471538
self.autoSelection()
472539
self.Outputs.scores.send(self.create_scores_table(labels))
473540

541+
def on_exception(self, ex: Exception) -> None:
    """Re-raise the exception passed in as ``ex``."""
    raise ex
543+
544+
def on_partial_result(self, result: Any) -> None:
    """Do nothing; ``result`` is intentionally ignored."""
546+
474547
def on_select(self):
475548
# Save indices of attributes in the original, unsorted domain
476549
selected_rows = self.ranksView.selectionModel().selectedRows(0)
@@ -530,7 +603,7 @@ def methodSelectionChanged(self, state, method_name):
530603
elif method_name in self.selected_methods:
531604
self.selected_methods.remove(method_name)
532605

533-
self.updateScores()
606+
self.update_scores()
534607

535608
def send_report(self):
536609
if not self.data:
@@ -621,4 +694,4 @@ def migrate_context(cls, context, version):
621694
WidgetPreview(OWRank).run(
622695
set_learner=(RandomForestLearner(), (3, 'Learner', None)),
623696
set_data=Table("heart_disease.tab"))
624-
"""
697+
"""

0 commit comments

Comments
 (0)