Skip to content

Commit 6cad0b3

Browse files
committed
OWSGD: Output coefficients
1 parent e3e0d0d commit 6cad0b3

File tree

3 files changed

+48
-1
lines changed

3 files changed

+48
-1
lines changed

Orange/widgets/classify/tests/test_owsgd.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
3+
from Orange.data import Table
34
from Orange.widgets.classify.owsgd import OWSGD
45
from Orange.widgets.tests.base import WidgetTest, WidgetLearnerTestMixin, \
56
ParameterMapping
@@ -23,3 +24,16 @@ def setUp(self):
2324
ParameterMapping.from_attribute(self.widget, 'eta0'),
2425
ParameterMapping.from_attribute(self.widget, 'power_t'),
2526
]
27+
28+
def test_output_coefficients(self):
29+
"""Check if coefficients are on output after apply"""
30+
self.assertIsNone(self.get_output("Coefficients"))
31+
self.send_signal("Data", self.data)
32+
self.widget.apply_button.button.click()
33+
coeffs = self.get_output("Coefficients")
34+
self.assertIsInstance(coeffs, Table)
35+
domain = self.data.domain
36+
self.assertEqual(coeffs.X.shape, (len(domain.attributes) + 1,
37+
len(domain.class_var.values)))
38+
self.send_signal("Data", None)
39+
self.assertIsNone(self.get_output("Coefficients"))

Orange/widgets/model/owsgd.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
from AnyQt.QtCore import Qt
44

55
from Orange.canvas.report import bool_str
6+
from Orange.data import ContinuousVariable, StringVariable, Domain, Table
67
from Orange.modelling.linear import SGDLearner
7-
from Orange.widgets import gui
8+
from Orange.widgets import gui, widget
9+
from Orange.widgets.classify.owlogisticregression import create_coef_table
810
from Orange.widgets.settings import Setting
911
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
1012

@@ -20,6 +22,8 @@ class OWSGD(OWBaseLearner):
2022

2123
LEARNER = SGDLearner
2224

25+
outputs = [("Coefficients", Table, widget.Explicit)]
26+
2327
reg_losses = (
2428
('Squared Loss', 'squared_loss'),
2529
('Huber', 'huber'),
@@ -281,6 +285,22 @@ def get_learner_parameters(self):
281285

282286
return list(params.items())
283287

288+
def update_model(self):
289+
super().update_model()
290+
coeffs = None
291+
if self.model is not None:
292+
if self.model.domain.class_var.is_discrete:
293+
coeffs = create_coef_table(self.model)
294+
else:
295+
attrs = [ContinuousVariable("coef", number_of_decimals=7)]
296+
domain = Domain(attrs, metas=[StringVariable("name")])
297+
cfs = list(self.model.intercept) + list(self.model.coefficients)
298+
names = ["intercept"] + \
299+
[attr.name for attr in self.model.domain.attributes]
300+
coeffs = Table(domain, list(zip(cfs, names)))
301+
coeffs.name = "coefficients"
302+
self.send("Coefficients", coeffs)
303+
284304

285305
if __name__ == '__main__':
286306
import sys

Orange/widgets/regression/tests/test_owsgdregression.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
3+
from Orange.data import Table
34
from Orange.widgets.regression.owsgdregression import OWSGD
45
from Orange.widgets.tests.base import WidgetTest, WidgetLearnerTestMixin, \
56
ParameterMapping
@@ -24,3 +25,15 @@ def setUp(self):
2425
ParameterMapping.from_attribute(self.widget, 'eta0'),
2526
ParameterMapping.from_attribute(self.widget, 'power_t'),
2627
]
28+
29+
def test_output_coefficients(self):
30+
"""Check if coefficients are on output after apply"""
31+
self.assertIsNone(self.get_output("Coefficients"))
32+
self.send_signal("Data", self.data)
33+
self.widget.apply_button.button.click()
34+
coeffs = self.get_output("Coefficients")
35+
self.assertIsInstance(coeffs, Table)
36+
self.assertEqual(coeffs.X.shape,
37+
(len(self.data.domain.attributes) + 1, 1))
38+
self.send_signal("Data", None)
39+
self.assertIsNone(self.get_output("Coefficients"))

0 commit comments

Comments
 (0)