1- import warnings
2- from collections import namedtuple , OrderedDict
31import logging
2+ import warnings
3+ from collections import OrderedDict , namedtuple
44from functools import partial
55from itertools import chain
6+ from types import SimpleNamespace
7+ from typing import Any , Callable , List , Tuple
68
79import numpy as np
8- from scipy .sparse import issparse
9-
10+ from AnyQt .QtCore import (
11+ QItemSelection , QItemSelectionModel , QItemSelectionRange , Qt
12+ )
1013from AnyQt .QtGui import QFontMetrics
1114from AnyQt .QtWidgets import (
12- QTableView , QRadioButton , QButtonGroup , QGridLayout ,
13- QStackedWidget , QHeaderView , QCheckBox , QItemDelegate ,
15+ QButtonGroup , QCheckBox , QGridLayout , QHeaderView , QItemDelegate ,
16+ QRadioButton , QStackedWidget , QTableView
1417)
15- from AnyQt .QtCore import (
16- Qt , QItemSelection , QItemSelectionRange , QItemSelectionModel ,
17- )
18-
1918from orangewidget .settings import IncompatibleContext
20- from Orange .data import (Table , Domain , ContinuousVariable , DiscreteVariable ,
21- StringVariable )
19+ from scipy .sparse import issparse
20+
21+ from Orange .data import (
22+ ContinuousVariable , DiscreteVariable , Domain , StringVariable , Table
23+ )
2224from Orange .data .util import get_unique_names_duplicates
23- from Orange .misc .cache import memoize_method
2425from Orange .preprocess import score
25- from Orange .widgets import report
26- from Orange .widgets import gui
27- from Orange .widgets .settings import (DomainContextHandler , Setting ,
28- ContextSetting )
26+ from Orange .widgets import gui , report
27+ from Orange .widgets .settings import (
28+ ContextSetting , DomainContextHandler , Setting
29+ )
30+ from Orange .widgets .unsupervised .owdistances import InterruptException
31+ from Orange .widgets .utils .concurrent import ConcurrentWidgetMixin , TaskState
2932from Orange .widgets .utils .itemmodels import PyTableModel
3033from Orange .widgets .utils .sql import check_sql_input
31- from Orange .widgets .utils .widgetpreview import WidgetPreview
3234from Orange .widgets .utils .state_summary import format_summary_details
33- from Orange .widgets .widget import (
34- OWWidget , Msg , Input , Output , AttributeList
35- )
36-
35+ from Orange .widgets .utils .widgetpreview import WidgetPreview
36+ from Orange .widgets .widget import AttributeList , Input , Msg , Output , OWWidget
3737
3838log = logging .getLogger (__name__ )
3939
@@ -167,7 +167,79 @@ def _argsortData(self, data, order):
167167 return indices
168168
169169
170- class OWRank (OWWidget ):
class Results(SimpleNamespace):
    """Container for the scores produced by one call of ``run``.

    Both attributes default to ``None``; ``run`` fills them through
    keyword arguments.
    """

    # One ``(method, scores)`` pair per built-in scoring method.
    method_scores: Tuple[Tuple[ScoreMeta, np.ndarray], ...] = None
    # One ``(scorer, (scores, labels))`` pair per external scorer, where
    # ``labels`` is the tuple of column names built by ``get_scorer_scores``.
    scorer_scores: Tuple[
        Tuple[ScoreMeta, Tuple[np.ndarray, Tuple[str, ...]]], ...
    ] = None
173+
174+
def get_method_scores(data: Table, method: ScoreMeta) -> np.ndarray:
    """Score the attributes of ``data`` with one built-in scoring method.

    The scorer is first called once on the whole table. If that raises
    ``ValueError`` it is called once per attribute instead; if that raises
    ``ValueError`` as well, an all-NaN vector with one entry per attribute
    is returned.
    """
    score = method.scorer()
    # NaNs and infs are dealt with by the widget, so numpy floating-point
    # warnings are silenced here. Real scorer errors must be caught elsewhere.
    with np.errstate(all="ignore"):
        try:
            return np.asarray(score(data))
        except ValueError:
            # Whole-table scoring failed: fall back to one call per attribute.
            try:
                per_attribute = np.array(
                    [score(data, attr) for attr in data.domain.attributes]
                )
            except ValueError:
                log.error("%s doesn't work on this data", method.name)
                return np.full(len(data.domain.attributes), np.nan)
            log.warning(
                "%s had to be computed separately for each variable",
                method.name,
            )
            return per_attribute
196+
197+
def get_scorer_scores(
    data: Table, scorer: ScoreMeta
) -> Tuple[np.ndarray, Tuple[str, ...]]:
    """Score ``data`` with an external scorer and label the score columns.

    Returns ``(scores, labels)``. ``scores`` is the transposed result of
    ``scorer.scorer.score_data(data)``; when that call raises ``ValueError``
    or ``TypeError`` it is a single NaN column with one row per attribute.
    ``labels`` holds one name per column of ``scores``: the scorer's short
    name alone for a single column, otherwise the short name suffixed with
    ``_1``, ``_2``, ...
    """
    try:
        scores = scorer.scorer.score_data(data).T
    except (ValueError, TypeError):
        log.error("%s doesn't work on this data", scorer.name)
        scores = np.full((len(data.domain.attributes), 1), np.nan)

    n_columns = scores.shape[1]
    if n_columns == 1:
        labels = (scorer.shortname,)
    else:
        # Several score columns: number them starting from 1.
        labels = tuple(
            scorer.shortname + "_" + str(i) for i in range(1, n_columns + 1)
        )
    return scores, labels
216+
217+
def run(
    data: Table,
    methods: List[ScoreMeta],
    scorers: List[ScoreMeta],
    state: TaskState,
) -> Results:
    """Compute scores for all ``methods`` and ``scorers`` on ``data``.

    After each score computation the progress on ``state`` is advanced and
    ``InterruptException`` is raised if an interruption was requested.

    Returns a ``Results`` whose ``method_scores`` and ``scorer_scores`` are
    tuples of ``(method_or_scorer, computed_value)`` pairs, in input order.
    """
    n_steps = len(methods) + len(scorers)
    # Drop the leading 0 so that finishing step k of n reports k/n * 100 and
    # the final step reports 100 (previously the first finished step
    # reported 0, and a single step never left 0).
    progress_steps = iter(np.linspace(0, 100, n_steps + 1)[1:])

    def call_with_cb(get_scores: Callable, method: ScoreMeta):
        scores = get_scores(data, method)
        state.set_progress_value(next(progress_steps))
        if state.is_interruption_requested():
            raise InterruptException
        return scores

    method_scores = tuple(
        (method, call_with_cb(get_method_scores, method)) for method in methods
    )
    scorer_scores = tuple(
        (scorer, call_with_cb(get_scorer_scores, scorer)) for scorer in scorers
    )
    return Results(method_scores=method_scores, scorer_scores=scorer_scores)
240+
241+
242+ class OWRank (OWWidget , ConcurrentWidgetMixin ):
171243 name = "Rank"
172244 description = "Rank and filter data features by their relevance."
173245 icon = "icons/Rank.svg"
@@ -211,20 +283,23 @@ class Warning(OWWidget.Warning):
211283 renamed_variables = Msg (
212284 "Variables with duplicated names have been renamed." )
213285
214-
215286 def __init__ (self ):
216- super ().__init__ ()
287+ OWWidget .__init__ (self )
288+ ConcurrentWidgetMixin .__init__ (self )
217289 self .scorers = OrderedDict ()
218290 self .out_domain_desc = None
219291 self .data = None
220292 self .problem_type_mode = ProblemType .CLASSIFICATION
221293
294+ # results caches
295+ self .scorers_results = {}
296+ self .methods_results = {}
297+
222298 if not self .selected_methods :
223299 self .selected_methods = {method .name for method in SCORES
224300 if method .is_default }
225301
226302 # GUI
227-
228303 self .ranksModel = model = TableModel (parent = self ) # type: TableModel
229304 self .ranksView = view = TableView (self ) # type: TableView
230305 self .mainArea .layout ().addWidget (view )
@@ -312,8 +387,9 @@ def set_data(self, data):
312387 self .ranksModel .clear ()
313388 self .ranksModel .resetSorting (True )
314389
315- self .get_method_scores .cache_clear () # pylint: disable=no-member
316- self .get_scorer_scores .cache_clear () # pylint: disable=no-member
390+ self .scorers_results = {}
391+ self .methods_results = {}
392+ self .cancel ()
317393
318394 self .Error .clear ()
319395 self .Information .clear ()
@@ -358,7 +434,7 @@ def set_data(self, data):
358434
359435 def handleNewSignals (self ):
360436 self .setStatusMessage ('Running' )
361- self .updateScores ()
437+ self .update_scores ()
362438 self .setStatusMessage ('' )
363439 self .on_select ()
364440
@@ -370,86 +446,75 @@ def set_learner(self, scorer, id): # pylint: disable=redefined-builtin
370446 # Avoid caching a (possibly stale) previous instance of the same
371447 # Scorer passed via the same signal
372448 if id in self .scorers :
373- # pylint: disable=no-member
374- self .get_scorer_scores .cache_clear ()
449+ self .scorers_results = {}
375450
376451 self .scorers [id ] = ScoreMeta (scorer .name , scorer .name , scorer ,
377452 ProblemType .from_variable (scorer .class_type ),
378453 False )
379454
380- @memoize_method ()
381- def get_method_scores (self , method ):
382- # These errors often happen, but they result in nans, which
383- # are handled correctly by the widget
384- estimator = method .scorer ()
385- data = self .data
386- # The widget handles infs and nans.
387- # Any errors in scorers need to be detected elsewhere.
388- with np .errstate (all = "ignore" ):
389- try :
390- scores = np .asarray (estimator (data ))
391- except ValueError :
392- try :
393- scores = np .array ([estimator (data , attr )
394- for attr in data .domain .attributes ])
395- except ValueError :
396- log .error ("%s doesn't work on this data" , method .name )
397- scores = np .full (len (data .domain .attributes ), np .nan )
398- else :
399- log .warning ("%s had to be computed separately for each "
400- "variable" , method .name )
401- return scores
402-
403- @memoize_method ()
404- def get_scorer_scores (self , scorer ):
405- try :
406- scores = scorer .scorer .score_data (self .data ).T
407- except (ValueError , TypeError ):
408- log .error ("%s doesn't work on this data" , scorer .name )
409- scores = np .full ((len (self .data .domain .attributes ), 1 ), np .nan )
410-
411- labels = ((scorer .shortname ,)
412- if scores .shape [1 ] == 1 else
413- tuple (scorer .shortname + '_' + str (i )
414- for i in range (1 , 1 + scores .shape [1 ])))
415- return scores , labels
416-
417- def updateScores (self ):
def _get_methods(self):
    """Return the built-in scoring methods that apply to the current state.

    A method from ``SCORES`` is kept when its name is in
    ``self.selected_methods``, its problem type equals
    ``self.problem_type_mode``, and either the data matrix is dense or the
    method's scorer supports sparse data.
    """
    applicable = []
    for method in SCORES:
        if method.name not in self.selected_methods:
            continue
        if method.problem_type != self.problem_type_mode:
            continue
        # Sparse data can only go to scorers that declare support for it.
        if issparse(self.data.X) and not method.scorer.supports_sparse_data:
            continue
        applicable.append(method)
    return applicable
468+
def _get_scorers(self):
    """Return the external scorers usable with the current problem type.

    A scorer is kept when its problem type is ``self.problem_type_mode`` or
    ``ProblemType.UNSUPERVISED``. For every other scorer the
    ``inadequate_learner`` error is raised on the widget and the scorer is
    left out of the result.
    """
    accepted_types = (self.problem_type_mode, ProblemType.UNSUPERVISED)
    usable = []
    for candidate in self.scorers.values():
        if candidate.problem_type not in accepted_types:
            self.Error.inadequate_learner(
                candidate.name, candidate.learner_adequacy_err_msg
            )
            continue
        usable.append(candidate)
    return usable
482+
483+ def update_scores (self ):
418484 if self .data is None :
419485 self .ranksModel .clear ()
420486 self .Outputs .scores .send (None )
421487 return
422488
423- methods = [method
424- for method in SCORES
425- if (method .name in self .selected_methods and
426- method .problem_type == self .problem_type_mode and
427- (not issparse (self .data .X ) or
428- method .scorer .supports_sparse_data ))]
429-
430- scorers = []
431489 self .Error .inadequate_learner .clear ()
432- for scorer in self .scorers .values ():
433- if scorer .problem_type in (self .problem_type_mode , ProblemType .UNSUPERVISED ):
434- scorers .append (scorer )
435- else :
436- self .Error .inadequate_learner (scorer .name , scorer .learner_adequacy_err_msg )
437490
438- method_scores = tuple (self .get_method_scores (method )
439- for method in methods )
491+ scorers = [
492+ s for s in self ._get_scorers () if s not in self .scorers_results
493+ ]
494+ methods = [
495+ m for m in self ._get_methods () if m not in self .methods_results
496+ ]
497+ self .start (run , self .data , methods , scorers )
440498
441- scorer_scores , scorer_labels = (), ()
442- if scorers :
443- scorer_scores , scorer_labels = zip (* (self .get_scorer_scores (scorer )
444- for scorer in scorers ))
445- scorer_labels = tuple (chain .from_iterable (scorer_labels ))
499+ def on_done (self , result : Results ) -> None :
500+ self .methods_results .update (result .method_scores )
501+ self .scorers_results .update (result .scorer_scores )
446502
447- labels = tuple (method .shortname for method in methods ) + scorer_labels
503+ methods = self ._get_methods ()
504+ method_labels = tuple (m .shortname for m in methods )
505+ method_scores = tuple (self .methods_results [m ] for m in methods )
506+
507+ scores = [self .scorers_results [s ] for s in self ._get_scorers ()]
508+ scorer_scores , scorer_labels = zip (* scores ) if scores else ((), ())
509+
510+ labels = method_labels + tuple (chain .from_iterable (scorer_labels ))
448511 model_array = np .column_stack (
449- ([len (a .values ) if a .is_discrete else np .nan
450- for a in self .data .domain .attributes ],) +
451- (method_scores if method_scores else ()) +
452- (scorer_scores if scorer_scores else ())
512+ (
513+ [len (a .values ) if a .is_discrete else np .nan
514+ for a in self .data .domain .attributes ],
515+ )
516+ + method_scores
517+ + scorer_scores
453518 )
454519 for column , values in enumerate (model_array .T ):
455520 self .ranksModel .setExtremesFrom (column , values )
@@ -464,13 +529,21 @@ def updateScores(self):
464529 if sort_column < len (labels ):
465530 # adds 1 for '#' (discrete count) column
466531 self .ranksModel .sort (sort_column + 1 , sort_order )
467- self .ranksView .horizontalHeader ().setSortIndicator (sort_column + 1 , sort_order )
532+ self .ranksView .horizontalHeader ().setSortIndicator (
533+ sort_column + 1 , sort_order
534+ )
468535 except ValueError :
469536 pass
470537
471538 self .autoSelection ()
472539 self .Outputs .scores .send (self .create_scores_table (labels ))
473540
def on_exception(self, ex: Exception) -> None:
    """Re-raise ``ex`` unchanged instead of handling it.

    NOTE(review): presumably called by ``ConcurrentWidgetMixin`` when the
    background task fails -- confirm against the mixin.
    """
    raise ex
543+
def on_partial_result(self, result: Any) -> None:
    """Ignore ``result``; the widget does nothing with partial results."""
546+
474547 def on_select (self ):
475548 # Save indices of attributes in the original, unsorted domain
476549 selected_rows = self .ranksView .selectionModel ().selectedRows (0 )
@@ -530,7 +603,7 @@ def methodSelectionChanged(self, state, method_name):
530603 elif method_name in self .selected_methods :
531604 self .selected_methods .remove (method_name )
532605
533- self .updateScores ()
606+ self .update_scores ()
534607
535608 def send_report (self ):
536609 if not self .data :
@@ -621,4 +694,4 @@ def migrate_context(cls, context, version):
621694 WidgetPreview(OWRank).run(
622695 set_learner=(RandomForestLearner(), (3, 'Learner', None)),
623696 set_data=Table("heart_disease.tab"))
624- """
697+ """
0 commit comments