Skip to content

Commit 7a6c034

Browse files
authored
Merge pull request #1448 from sstanovnik/fix-logreg-weights
[FIX] Stop advertising support for weights in LogisticRegression.
2 parents d6a324e + c5d211b commit 7a6c034

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

Orange/classification/logistic_regression.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,8 @@ def __init__(self, penalty="l2", dual=False, tol=0.0001, C=1.0,
4141
random_state=None, preprocessors=None):
4242
super().__init__(preprocessors=preprocessors)
4343
self.params = vars()
44+
45+
@property
46+
def supports_weights(self):
47+
# liblinear (default) cannot handle weights
48+
return False

Orange/tests/test_base.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
import unittest
44

55
from Orange.base import SklLearner
6+
from Orange.regression import LinearRegressionLearner
67
from Orange.classification import LogisticRegressionLearner
8+
from Orange.data import Table
9+
10+
from sklearn.linear_model import LogisticRegression
711

812

913
class TestSklLearner(unittest.TestCase):
@@ -26,7 +30,19 @@ class DummyLearner(SklLearner):
2630

2731
self.assertFalse(DummyLearner().supports_weights)
2832

29-
def test_logreg(self):
30-
self.assertTrue(LogisticRegressionLearner().supports_weights,
31-
"Either LogisticRegression no longer supports weighted tables"
33+
def test_linreg(self):
34+
self.assertTrue(LinearRegressionLearner().supports_weights,
35+
"Either LinearRegression no longer supports weighted tables "
3236
"or SklLearner.supports_weights is out-of-date.")
37+
38+
def test_logreg(self):
39+
self.assertFalse(LogisticRegressionLearner().supports_weights,
40+
"Logistic regression has its supports_weights overridden because "
41+
"liblinear doesn't support them (even though the parameter exists)")
42+
43+
def test_assert_liblinear_doesnt_accept_weights(self):
44+
data = Table('iris')
45+
data.set_weights(1.2)
46+
with self.assertRaises(ValueError):
47+
skl = LogisticRegression(solver='liblinear')
48+
skl.fit(data.X, data.Y, data.W)

0 commit comments

Comments
 (0)