Skip to content

Commit 2fbe656

Browse files
authored
Merge pull request #1981 from VesnaT/svm_coeff
[ENH] OWSGD: Output coefficients
2 parents 5970d21 + 6cad0b3 commit 2fbe656

File tree

7 files changed: 85 additions and 23 deletions

7 files changed: 85 additions and 23 deletions

Orange/classification/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
from .simple_random_forest import *
1515
from .elliptic_envelope import *
1616
from .rules import *
17+
from .sgd import *

Orange/classification/sgd.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
from sklearn.linear_model import SGDClassifier
2-
from sklearn.pipeline import Pipeline
3-
from sklearn.preprocessing import StandardScaler
42

53
from Orange.base import SklLearner
4+
from Orange.preprocess import Normalize
65
from Orange.regression.linear import LinearModel
76

7+
__all__ = ["SGDClassificationLearner"]
8+
89

910
class SGDClassificationLearner(SklLearner):
1011
name = 'sgd'
1112
__wraps__ = SGDClassifier
13+
__returns__ = LinearModel
14+
preprocessors = SklLearner.preprocessors + [Normalize()]
1215

1316
def __init__(self, loss='squared_loss',penalty='l2', alpha=0.0001,
1417
l1_ratio=0.15,fit_intercept=True, n_iter=5, shuffle=True,
@@ -17,9 +20,3 @@ def __init__(self, loss='squared_loss',penalty='l2', alpha=0.0001,
1720
preprocessors=None):
1821
super().__init__(preprocessors=preprocessors)
1922
self.params = vars()
20-
21-
def fit(self, X, Y, W):
22-
sk = self.__wraps__(**self.params)
23-
clf = Pipeline([('scaler', StandardScaler()), ('sgd', sk)])
24-
clf.fit(X, Y.ravel())
25-
return LinearModel(clf)

Orange/regression/linear.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import numpy as np
22

33
import sklearn.linear_model as skl_linear_model
4-
import sklearn.pipeline as skl_pipeline
54
import sklearn.preprocessing as skl_preprocessing
65

76
from Orange.data import Variable, ContinuousVariable
8-
from Orange.preprocess import Continuize, Normalize, RemoveNaNColumns, SklImpute
7+
from Orange.preprocess import Normalize
98
from Orange.preprocess.score import LearnerScorer
109
from Orange.regression import Learner, Model, SklLearner, SklModel
1110

@@ -82,6 +81,7 @@ def __init__(self, l1_ratio=0.5, eps=0.001, n_alphas=100, alphas=None,
8281

8382
class SGDRegressionLearner(LinearRegressionLearner):
8483
__wraps__ = skl_linear_model.SGDRegressor
84+
preprocessors = SklLearner.preprocessors + [Normalize()]
8585

8686
def __init__(self, loss='squared_loss',penalty='l2', alpha=0.0001,
8787
l1_ratio=0.15, fit_intercept=True, n_iter=5, shuffle=True,
@@ -92,13 +92,6 @@ def __init__(self, loss='squared_loss',penalty='l2', alpha=0.0001,
9292
super().__init__(preprocessors=preprocessors)
9393
self.params = vars()
9494

95-
def fit(self, X, Y, W):
96-
sk = self.__wraps__(**self.params)
97-
clf = skl_pipeline.Pipeline(
98-
[('scaler', skl_preprocessing.StandardScaler()), ('sgd', sk)])
99-
clf.fit(X, Y.ravel())
100-
return LinearModel(clf)
101-
10295

10396
class PolynomialLearner(Learner):
10497
"""Generate polynomial features and learn a prediction model

Orange/tests/test_sgd.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,39 @@
55

66
import numpy as np
77

8-
import Orange
8+
from Orange.data import Table
9+
from Orange.classification import SGDClassificationLearner
10+
from Orange.regression import SGDRegressionLearner
11+
from Orange.evaluation import CrossValidation, RMSE, AUC
912

1013

1114
class TestSGDRegressionLearner(unittest.TestCase):
1215
def test_SGDRegression(self):
1316
nrows, ncols = 500, 5
1417
X = np.random.rand(nrows, ncols)
1518
y = X.dot(np.random.rand(ncols))
16-
data = Orange.data.Table(X, y)
17-
sgd = Orange.regression.SGDRegressionLearner()
18-
res = Orange.evaluation.CrossValidation(data, [sgd], k=3)
19-
self.assertLess(Orange.evaluation.RMSE(res)[0], 0.1)
19+
data = Table(X, y)
20+
sgd = SGDRegressionLearner()
21+
res = CrossValidation(data, [sgd], k=3)
22+
self.assertLess(RMSE(res)[0], 0.1)
23+
24+
def test_coefficients(self):
25+
lrn = SGDRegressionLearner()
26+
mod = lrn(Table("housing"))
27+
self.assertEqual(len(mod.coefficients), len(mod.domain.attributes))
28+
29+
30+
class TestSGDClassificationLearner(unittest.TestCase):
31+
@classmethod
32+
def setUpClass(cls):
33+
cls.iris = Table('iris')
34+
35+
def test_SGDClassification(self):
36+
sgd = SGDClassificationLearner()
37+
res = CrossValidation(self.iris, [sgd], k=3)
38+
self.assertGreater(AUC(res)[0], 0.85)
39+
40+
def test_coefficients(self):
41+
lrn = SGDClassificationLearner()
42+
mod = lrn(self.iris)
43+
self.assertEqual(len(mod.coefficients[0]), len(mod.domain.attributes))

Orange/widgets/classify/tests/test_owsgd.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
3+
from Orange.data import Table
34
from Orange.widgets.classify.owsgd import OWSGD
45
from Orange.widgets.tests.base import WidgetTest, WidgetLearnerTestMixin, \
56
ParameterMapping
@@ -23,3 +24,16 @@ def setUp(self):
2324
ParameterMapping.from_attribute(self.widget, 'eta0'),
2425
ParameterMapping.from_attribute(self.widget, 'power_t'),
2526
]
27+
28+
def test_output_coefficients(self):
29+
"""Check if coefficients are on output after apply"""
30+
self.assertIsNone(self.get_output("Coefficients"))
31+
self.send_signal("Data", self.data)
32+
self.widget.apply_button.button.click()
33+
coeffs = self.get_output("Coefficients")
34+
self.assertIsInstance(coeffs, Table)
35+
domain = self.data.domain
36+
self.assertEqual(coeffs.X.shape, (len(domain.attributes) + 1,
37+
len(domain.class_var.values)))
38+
self.send_signal("Data", None)
39+
self.assertIsNone(self.get_output("Coefficients"))

Orange/widgets/model/owsgd.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
from AnyQt.QtCore import Qt
44

55
from Orange.canvas.report import bool_str
6+
from Orange.data import ContinuousVariable, StringVariable, Domain, Table
67
from Orange.modelling.linear import SGDLearner
7-
from Orange.widgets import gui
8+
from Orange.widgets import gui, widget
9+
from Orange.widgets.classify.owlogisticregression import create_coef_table
810
from Orange.widgets.settings import Setting
911
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
1012

@@ -20,6 +22,8 @@ class OWSGD(OWBaseLearner):
2022

2123
LEARNER = SGDLearner
2224

25+
outputs = [("Coefficients", Table, widget.Explicit)]
26+
2327
reg_losses = (
2428
('Squared Loss', 'squared_loss'),
2529
('Huber', 'huber'),
@@ -281,6 +285,22 @@ def get_learner_parameters(self):
281285

282286
return list(params.items())
283287

288+
def update_model(self):
289+
super().update_model()
290+
coeffs = None
291+
if self.model is not None:
292+
if self.model.domain.class_var.is_discrete:
293+
coeffs = create_coef_table(self.model)
294+
else:
295+
attrs = [ContinuousVariable("coef", number_of_decimals=7)]
296+
domain = Domain(attrs, metas=[StringVariable("name")])
297+
cfs = list(self.model.intercept) + list(self.model.coefficients)
298+
names = ["intercept"] + \
299+
[attr.name for attr in self.model.domain.attributes]
300+
coeffs = Table(domain, list(zip(cfs, names)))
301+
coeffs.name = "coefficients"
302+
self.send("Coefficients", coeffs)
303+
284304

285305
if __name__ == '__main__':
286306
import sys

Orange/widgets/regression/tests/test_owsgdregression.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
3+
from Orange.data import Table
34
from Orange.widgets.regression.owsgdregression import OWSGD
45
from Orange.widgets.tests.base import WidgetTest, WidgetLearnerTestMixin, \
56
ParameterMapping
@@ -24,3 +25,15 @@ def setUp(self):
2425
ParameterMapping.from_attribute(self.widget, 'eta0'),
2526
ParameterMapping.from_attribute(self.widget, 'power_t'),
2627
]
28+
29+
def test_output_coefficients(self):
30+
"""Check if coefficients are on output after apply"""
31+
self.assertIsNone(self.get_output("Coefficients"))
32+
self.send_signal("Data", self.data)
33+
self.widget.apply_button.button.click()
34+
coeffs = self.get_output("Coefficients")
35+
self.assertIsInstance(coeffs, Table)
36+
self.assertEqual(coeffs.X.shape,
37+
(len(self.data.domain.attributes) + 1, 1))
38+
self.send_signal("Data", None)
39+
self.assertIsNone(self.get_output("Coefficients"))

0 commit comments

Comments
 (0)