Skip to content

Commit 9d4e8af

Browse files
authored
Merge pull request #2350 from kernc/rank-learners
[FIX] Rank widget supports Scorer inputs
2 parents 4302132 + 67886bd commit 9d4e8af

File tree

10 files changed

+103
-18
lines changed

10 files changed

+103
-18
lines changed

Orange/modelling/base.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from Orange.base import Learner, Model, SklLearner
4+
from Orange.data import Table, Domain
45

56

67
class Fitter(Learner):
@@ -31,10 +32,7 @@ def __init__(self, preprocessors=None, **kwargs):
3132
self.__learners = {self.CLASSIFICATION: None, self.REGRESSION: None}
3233

3334
def _fit_model(self, data):
34-
if data.domain.has_discrete_class:
35-
learner = self.get_learner(self.CLASSIFICATION)
36-
else:
37-
learner = self.get_learner(self.REGRESSION)
35+
learner = self.get_learner(data)
3836

3937
if type(self).fit is Learner.fit:
4038
return learner.fit_storage(data)
@@ -43,20 +41,34 @@ def _fit_model(self, data):
4341
return learner.fit(X, Y, W)
4442

4543
def preprocess(self, data):
46-
if data.domain.has_discrete_class:
47-
return self.get_learner(self.CLASSIFICATION).preprocess(data)
48-
else:
49-
return self.get_learner(self.REGRESSION).preprocess(data)
44+
return self.get_learner(data).preprocess(data)
5045

5146
def get_learner(self, problem_type):
5247
"""Get the learner for a given problem type.
5348
49+
Parameters
50+
----------
51+
problem_type: str or Table or Domain
52+
If str, one of ``'classification'`` or ``'regression'``. If Table
53+
or Domain, the type is inferred from Domain's first class variable.
54+
5455
Returns
5556
-------
5657
Learner
5758
The appropriate learner for the given problem type.
5859
60+
Raises
61+
------
62+
TypeError
63+
When (inferred) problem type not one of ``'classification'``
64+
or ``'regression'``.
5965
"""
66+
if isinstance(problem_type, Table):
67+
problem_type = problem_type.domain
68+
if isinstance(problem_type, Domain):
69+
problem_type = (self.CLASSIFICATION if problem_type.has_discrete_class else
70+
self.REGRESSION if problem_type.has_continuous_class else
71+
None)
6072
# Prevent trying to access the learner when problem type is None
6173
if problem_type not in self.__fits__:
6274
raise TypeError("No learner to handle '{}'".format(problem_type))
@@ -112,8 +124,5 @@ class SklFitter(Fitter):
112124
def _fit_model(self, data):
113125
model = super()._fit_model(data)
114126
model.used_vals = [np.unique(y) for y in data.Y[:, None].T]
115-
if data.domain.has_discrete_class:
116-
model.params = self.get_params(self.CLASSIFICATION)
117-
else:
118-
model.params = self.get_params(self.REGRESSION)
127+
model.params = self.get_params(data)
119128
return model

Orange/modelling/linear.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,24 @@
1+
import numpy as np
2+
13
from Orange.classification.sgd import SGDClassificationLearner
4+
from Orange.data import Variable
25
from Orange.modelling import SklFitter
6+
from Orange.preprocess.score import LearnerScorer
37
from Orange.regression import SGDRegressionLearner
48

59
__all__ = ['SGDLearner']
610

711

8-
class SGDLearner(SklFitter):
12+
class _FeatureScorerMixin(LearnerScorer):
13+
feature_type = Variable
14+
class_type = Variable
15+
16+
def score(self, data):
17+
model = self.get_learner(data)(data)
18+
return np.atleast_2d(np.abs(model.skl_model.coef_)).mean(0)
19+
20+
21+
class SGDLearner(SklFitter, _FeatureScorerMixin):
922
name = 'sgd'
1023

1124
__fits__ = {'classification': SGDClassificationLearner,

Orange/modelling/randomforest.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
from Orange.base import RandomForestModel
22
from Orange.classification import RandomForestLearner as RFClassification
3+
from Orange.data import Variable
34
from Orange.modelling import SklFitter
5+
from Orange.preprocess.score import LearnerScorer
46
from Orange.regression import RandomForestRegressionLearner as RFRegression
57

68
__all__ = ['RandomForestLearner']
79

810

9-
class RandomForestLearner(SklFitter):
11+
class _FeatureScorerMixin(LearnerScorer):
12+
feature_type = Variable
13+
class_type = Variable
14+
15+
def score(self, data):
16+
model = self.get_learner(data)(data)
17+
return model.skl_model.feature_importances_
18+
19+
20+
class RandomForestLearner(SklFitter, _FeatureScorerMixin):
1021
name = 'random forest'
1122

1223
__fits__ = {'classification': RFClassification,

Orange/preprocess/score.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,10 @@ def score(self, data):
154154
raise NotImplementedError
155155

156156
def score_data(self, data, feature=None):
157-
scores = self.score(data)
158157

159158
def average_scores(scores):
160159
scores_grouped = defaultdict(list)
161-
for attr, score in zip(self.domain.attributes, scores):
160+
for attr, score in zip(model_domain.attributes, scores):
162161
# Go up the chain of preprocessors to obtain the original variable
163162
while getattr(attr, 'compute_value', False):
164163
attr = getattr(attr.compute_value, 'variable', attr)
@@ -167,8 +166,14 @@ def average_scores(scores):
167166
if attr in scores_grouped else 0
168167
for attr in data.domain.attributes]
169168

170-
scores = np.atleast_2d(scores)
171-
if data.domain != self.domain:
169+
scores = np.atleast_2d(self.score(data))
170+
171+
from Orange.modelling import Fitter # Avoid recursive import
172+
model_domain = (self.get_learner(data).domain
173+
if isinstance(self, Fitter) else
174+
self.domain)
175+
176+
if data.domain != model_domain:
172177
scores = np.array([average_scores(row) for row in scores])
173178

174179
return scores[:, data.domain.attributes.index(feature)] \

Orange/widgets/data/tests/test_owrank.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import numpy as np
22

33
from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable
4+
from Orange.modelling import RandomForestLearner, SGDLearner
45
from Orange.preprocess.score import Scorer
56
from Orange.classification import LogisticRegressionLearner
67
from Orange.regression import LinearRegressionLearner
78
from Orange.projection import PCA
89
from Orange.widgets.data.owrank import OWRank
910
from Orange.widgets.tests.base import WidgetTest
1011

12+
from AnyQt.QtCore import Qt
13+
1114

1215
class TestOWRank(WidgetTest):
1316
def setUp(self):
@@ -39,6 +42,31 @@ def test_input_scorer(self):
3942
self.assertEqual(self.log_reg, value.score)
4043
self.assertIsInstance(value.score, Scorer)
4144

45+
def test_input_scorer_fitter(self):
46+
heart_disease = Table('heart_disease')
47+
self.assertEqual(self.widget.learners, {})
48+
49+
for fitter, name in ((RandomForestLearner(), 'random forest'),
50+
(SGDLearner(), 'sgd')):
51+
with self.subTest(fitter=fitter):
52+
self.send_signal("Scorer", fitter, 1)
53+
54+
for data, model in ((self.housing, self.widget.contRanksModel),
55+
(heart_disease, self.widget.discRanksModel)):
56+
with self.subTest(data=data.name):
57+
self.send_signal('Data', data)
58+
scores = [model.data(model.index(row, model.columnCount() - 1))
59+
for row in range(model.rowCount())]
60+
self.assertEqual(len(scores), len(data.domain.attributes))
61+
self.assertFalse(np.isnan(scores).any())
62+
63+
last_column = model.headerData(
64+
model.columnCount() - 1, Qt.Horizontal).lower()
65+
self.assertIn(name, last_column)
66+
67+
self.send_signal("Scorer", None, 1)
68+
self.assertEqual(self.widget.learners, {})
69+
4270
def test_input_scorer_disconnect(self):
4371
"""Check widget's scorer after disconnecting scorer on the input"""
4472
self.send_signal(self.widget.Inputs.scorer, self.log_reg, 1)

doc/visual-programming/source/widgets/data/rank.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ Signals
1414

1515
An input data set.
1616

17+
- **Scorer** (multiple)
18+
19+
Models that implement the feature scoring interface, such as linear /
20+
logistic regression, random forest, stochastic gradient descent, etc.
21+
1722
**Outputs**:
1823

1924
- **Reduced Data**
@@ -47,6 +52,12 @@ Scoring methods
4752
6. `ReliefF <https://en.wikipedia.org/wiki/Relief_(feature_selection)>`_: the ability of an attribute to distinguish between classes on similar data instances
4853
7. `FCBF (Fast Correlation Based Filter) <https://www.aaai.org/Papers/ICML/2003/ICML03-111.pdf>`_: entropy-based measure, which also identifies redundancy due to pairwise correlations between features
4954

55+
Additionally, you can connect certain learners that enable scoring the features
56+
according to how important they are in models that the learners build (e.g.
57+
:ref:`Linear <model.lr>` / :ref:`Logistic Regression <model.logit>`,
58+
:ref:`Random Forest <model.rf>`, :ref:`SGD <model.sgd>`, …).
59+
60+
5061
Example: Attribute Ranking and Selection
5162
----------------------------------------
5263

doc/visual-programming/source/widgets/model/linearregression.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
.. _model.lr:
2+
13
Linear Regression
24
=================
35

doc/visual-programming/source/widgets/model/logisticregression.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
.. _model.logit:
2+
13
Logistic Regression
24
===================
35

doc/visual-programming/source/widgets/model/randomforest.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
.. _model.rf:
2+
13
Random Forest
24
=============
35

doc/visual-programming/source/widgets/model/stochasticgradient.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
.. _model.sgd:
2+
13
Stochastic Gradient Descent
24
===========================
35

0 commit comments

Comments
 (0)