Skip to content

Commit 6790a6c

Browse files
authored
Merge pull request #4780 from pavlin-policar/multinomial
logistic regression: sync defaults with scikit-learn
2 parents 6169237 + 7b3130a commit 6790a6c

File tree

5 files changed

+98
-57
lines changed

5 files changed

+98
-57
lines changed

Orange/classification/logistic_regression.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,22 @@ class LogisticRegressionLearner(SklLearner, _FeatureScorerMixin):
3636

3737
def __init__(self, penalty="l2", dual=False, tol=0.0001, C=1.0,
3838
fit_intercept=True, intercept_scaling=1, class_weight=None,
39-
random_state=None, solver='liblinear', max_iter=100,
40-
multi_class='ovr', verbose=0, n_jobs=1, preprocessors=None):
39+
random_state=None, solver="auto", max_iter=100,
40+
multi_class="auto", verbose=0, n_jobs=1, preprocessors=None):
4141
super().__init__(preprocessors=preprocessors)
4242
self.params = vars()
43+
44+
def _initialize_wrapped(self):
45+
params = self.params.copy()
46+
# The default scikit-learn solver `lbfgs` (v0.22) does not support the
47+
# l1 penalty.
48+
solver, penalty = params.pop("solver"), params.get("penalty")
49+
if solver == "auto":
50+
if penalty == "l1":
51+
solver = "liblinear"
52+
else:
53+
solver = "lbfgs"
54+
params["solver"] = solver
55+
56+
return self.__wraps__(**params)
57+

Orange/tests/test_evaluation_scoring.py

Lines changed: 36 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,12 @@ def setUpClass(cls):
5858
def test_precision_iris(self):
5959
learner = LogisticRegressionLearner(preprocessors=[])
6060
res = TestOnTrainingData()(self.iris, [learner])
61-
self.assertAlmostEqual(self.score(res, average='weighted')[0],
62-
0.96189, 5)
63-
self.assertAlmostEqual(self.score(res, target=1)[0], 0.97826, 5)
64-
self.assertAlmostEqual(self.score(res, target=1, average=None)[0],
65-
0.97826, 5)
66-
self.assertAlmostEqual(self.score(res, target=1, average='weighted')[0],
67-
0.97826, 5)
68-
self.assertAlmostEqual(self.score(res, target=0, average=None)[0], 1, 5)
69-
self.assertAlmostEqual(self.score(res, target=2, average=None)[0],
70-
0.90741, 5)
61+
self.assertGreater(self.score(res, average='weighted')[0], 0.95)
62+
self.assertGreater(self.score(res, target=1)[0], 0.95)
63+
self.assertGreater(self.score(res, target=1, average=None)[0], 0.95)
64+
self.assertGreater(self.score(res, target=1, average='weighted')[0], 0.95)
65+
self.assertGreater(self.score(res, target=0, average=None)[0], 0.99)
66+
self.assertGreater(self.score(res, target=2, average=None)[0], 0.94)
7167

7268
def test_precision_multiclass(self):
7369
results = Results(
@@ -117,15 +113,12 @@ def setUpClass(cls):
117113
def test_recall_iris(self):
118114
learner = LogisticRegressionLearner(preprocessors=[])
119115
res = TestOnTrainingData()(self.iris, [learner])
120-
self.assertAlmostEqual(self.score(res, average='weighted')[0], 0.96, 5)
121-
self.assertAlmostEqual(self.score(res, target=1)[0], 0.9, 5)
122-
self.assertAlmostEqual(self.score(res, target=1, average=None)[0],
123-
0.9, 5)
124-
self.assertAlmostEqual(self.score(res, target=1, average='weighted')[0],
125-
0.9, 5)
126-
self.assertAlmostEqual(self.score(res, target=0, average=None)[0], 1, 5)
127-
self.assertAlmostEqual(self.score(res, target=2, average=None)[0],
128-
0.98, 5)
116+
self.assertGreater(self.score(res, average='weighted')[0], 0.96)
117+
self.assertGreater(self.score(res, target=1)[0], 0.9)
118+
self.assertGreater(self.score(res, target=1, average=None)[0], 0.9)
119+
self.assertGreater(self.score(res, target=1, average='weighted')[0], 0.9)
120+
self.assertGreater(self.score(res, target=0, average=None)[0], 0.99)
121+
self.assertGreater(self.score(res, target=2, average=None)[0], 0.97)
129122

130123
def test_recall_multiclass(self):
131124
results = Results(
@@ -175,16 +168,12 @@ def setUpClass(cls):
175168
def test_recall_iris(self):
176169
learner = LogisticRegressionLearner(preprocessors=[])
177170
res = TestOnTrainingData()(self.iris, [learner])
178-
self.assertAlmostEqual(self.score(res, average='weighted')[0],
179-
0.959935, 5)
180-
self.assertAlmostEqual(self.score(res, target=1)[0], 0.9375, 5)
181-
self.assertAlmostEqual(self.score(res, target=1, average=None)[0],
182-
0.9375, 5)
183-
self.assertAlmostEqual(self.score(res, target=1, average='weighted')[0],
184-
0.9375, 5)
185-
self.assertAlmostEqual(self.score(res, target=0, average=None)[0], 1, 5)
186-
self.assertAlmostEqual(self.score(res, target=2, average=None)[0],
187-
0.942307, 5)
171+
self.assertGreater(self.score(res, average='weighted')[0], 0.95)
172+
self.assertGreater(self.score(res, target=1)[0], 0.95)
173+
self.assertGreater(self.score(res, target=1, average=None)[0], 0.95)
174+
self.assertGreater(self.score(res, target=1, average='weighted')[0], 0.95)
175+
self.assertGreater(self.score(res, target=0, average=None)[0], 0.99)
176+
self.assertGreater(self.score(res, target=2, average=None)[0], 0.95)
188177

189178
def test_F1_multiclass(self):
190179
results = Results(
@@ -377,16 +366,24 @@ def setUpClass(cls):
377366
def test_specificity_iris(self):
378367
learner = LogisticRegressionLearner(preprocessors=[])
379368
res = TestOnTrainingData()(self.iris, [learner])
380-
self.assertAlmostEqual(self.score(res, average='weighted')[0],
381-
(1 + 0.99 + 0.95) / 3, 5)
382-
self.assertAlmostEqual(self.score(res, target=1)[0], 99 / (99 + 1), 5)
383-
self.assertAlmostEqual(self.score(res, target=1, average=None)[0],
384-
99 / (99 + 1), 5)
385-
self.assertAlmostEqual(self.score(res, target=1, average='weighted')[0],
386-
99 / (99 + 1), 5)
387-
self.assertAlmostEqual(self.score(res, target=0, average=None)[0], 1, 5)
388-
self.assertAlmostEqual(self.score(res, target=2, average=None)[0],
389-
95 / (95 + 5), 5)
369+
self.assertGreaterEqual(
370+
self.score(res, average='weighted')[0], (1 + 0.99 + 0.95) / 3
371+
)
372+
self.assertGreaterEqual(
373+
self.score(res, target=1)[0], 99 / (99 + 1)
374+
)
375+
self.assertGreaterEqual(
376+
self.score(res, target=1, average=None)[0], 99 / (99 + 1)
377+
)
378+
self.assertGreaterEqual(
379+
self.score(res, target=1, average='weighted')[0], 99 / (99 + 1)
380+
)
381+
self.assertGreaterEqual(
382+
self.score(res, target=0, average=None)[0], 1
383+
)
384+
self.assertGreaterEqual(
385+
self.score(res, target=2, average=None)[0], 95 / (95 + 5)
386+
)
390387

391388
def test_precision_multiclass(self):
392389
results = Results(

Orange/tests/test_logistic_regression.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_probability(self):
6666
def test_learner_scorer(self):
6767
learner = LogisticRegressionLearner()
6868
scores = learner.score_data(self.heart_disease)
69-
self.assertEqual('major vessels colored',
69+
self.assertEqual('chest pain',
7070
self.heart_disease.domain.attributes[np.argmax(scores)].name)
7171
self.assertEqual(scores.shape, (1, len(self.heart_disease.domain.attributes)))
7272

@@ -89,13 +89,13 @@ def test_learner_scorer_multiclass(self):
8989
attr = self.zoo.domain.attributes
9090
learner = LogisticRegressionLearner()
9191
scores = learner.score_data(self.zoo)
92-
self.assertEqual('aquatic', attr[np.argmax(scores[0])].name) # amphibian
92+
self.assertEqual('legs', attr[np.argmax(scores[0])].name) # amphibian
9393
self.assertEqual('feathers', attr[np.argmax(scores[1])].name) # bird
9494
self.assertEqual('fins', attr[np.argmax(scores[2])].name) # fish
9595
self.assertEqual('legs', attr[np.argmax(scores[3])].name) # insect
9696
self.assertEqual('backbone', attr[np.argmax(scores[4])].name) # invertebrate
9797
self.assertEqual('milk', attr[np.argmax(scores[5])].name) # mammal
98-
self.assertEqual('hair', attr[np.argmax(scores[6])].name) # reptile
98+
self.assertEqual('aquatic', attr[np.argmax(scores[6])].name) # reptile
9999
self.assertEqual(scores.shape,
100100
(len(self.zoo.domain.class_var.values), len(attr)))
101101

@@ -131,3 +131,23 @@ def test_sklearn_single_class(self):
131131
self.assertEqual(len(np.unique(t.Y)), 1)
132132
lr = sklearn.linear_model.LogisticRegression()
133133
self.assertRaises(ValueError, lr.fit, t.X, t.Y)
134+
135+
def test_auto_solver(self):
136+
# These defaults are valid as of sklearn v0.23.0
137+
# lbfgs is default for l2 penalty
138+
lr = LogisticRegressionLearner(penalty="l2", solver="auto")
139+
skl_clf = lr._initialize_wrapped()
140+
self.assertEqual(skl_clf.solver, "lbfgs")
141+
self.assertEqual(skl_clf.penalty, "l2")
142+
143+
# lbfgs is default for no penalty
144+
lr = LogisticRegressionLearner(penalty=None, solver="auto")
145+
skl_clf = lr._initialize_wrapped()
146+
self.assertEqual(skl_clf.solver, "lbfgs")
147+
self.assertEqual(skl_clf.penalty, None)
148+
149+
# liblinear is default for l2 penalty
150+
lr = LogisticRegressionLearner(penalty="l1", solver="auto")
151+
skl_clf = lr._initialize_wrapped()
152+
self.assertEqual(skl_clf.solver, "liblinear")
153+
self.assertEqual(skl_clf.penalty, "l1")

Orange/widgets/evaluate/tests/test_owtestandscore.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -382,26 +382,33 @@ def test_scores_log_reg_bad2(self):
382382
table_test = Table.from_list(
383383
self.scores_domain,
384384
list(zip(*(self.scores_table_values + [list("yynn")]))))
385-
self.assertTupleEqual(self._test_scores(
386-
table_train, table_test, LogisticRegressionLearner(),
387-
OWTestAndScore.TestOnTest, None),
388-
(0, 0, 0, 0, 0))
385+
386+
lr = LogisticRegressionLearner()
387+
np.testing.assert_almost_equal(
388+
self._test_scores(
389+
table_train, table_test, lr, OWTestAndScore.TestOnTest, None
390+
),
391+
(0, 0.25, 0.2, 0.1666666, 0.25),
392+
)
389393

390394
def test_scores_log_reg_advanced(self):
391395
table_train = Table.from_list(
392-
self.scores_domain, list(zip(
393-
[1, 1, 1.23, 23.8, 5.], [1., 2., 3., 4., 3.], "yyynn"))
396+
self.scores_domain,
397+
list(zip([1, 1, 1.23, 23.8, 5.], [1., 2., 3., 4., 3.], "yyynn"))
394398
)
395399
table_test = Table.from_list(
396-
self.scores_domain, list(zip(
397-
[1, 1, 1.23, 23.8, 5.], [1., 2., 3., 4., 3.], "yynnn"))
400+
self.scores_domain,
401+
list(zip([1, 1, 1.23, 23.8, 5.], [1., 2., 3., 4., 3.], "yynnn"))
398402
)
399403

404+
lr = LogisticRegressionLearner()
405+
np.testing.assert_
400406
np.testing.assert_almost_equal(
401-
self._test_scores(table_train, table_test,
402-
LogisticRegressionLearner(),
403-
OWTestAndScore.TestOnTest, None),
404-
(2 / 3, 0.8, 0.8, 13 / 15, 0.8))
407+
self._test_scores(
408+
table_train, table_test, lr, OWTestAndScore.TestOnTest, None
409+
),
410+
(1, 0.8, 0.8, 13 / 15, 0.8)
411+
)
405412

406413
def test_scores_cross_validation(self):
407414
"""

Orange/widgets/visualize/tests/test_ownomogram.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ def test_nomogram_nb_multiclass(self):
9797
def test_nomogram_lr_multiclass(self):
9898
"""Check probabilities for logistic regression classifier for various
9999
values of classes and radio buttons for multiclass data"""
100-
cls = LogisticRegressionLearner()(self.lenses)
100+
cls = LogisticRegressionLearner(
101+
multi_class="ovr", solver="liblinear"
102+
)(self.lenses)
101103
self._test_helper(cls, [9, 45, 52])
102104

103105
def test_nomogram_with_instance_nb(self):

0 commit comments

Comments
 (0)