Skip to content

Commit 2fc3112

Browse files
committed
Models: output probabilities with correct shape
1 parent cf781b1 commit 2fc3112

File tree

3 files changed

+94
-52
lines changed

3 files changed

+94
-52
lines changed

Orange/classification/softmax_regression.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ def cost_grad(self, Theta_flat, X, Y):
7070

7171
return cost, grad.ravel()
7272

73-
def fit(self, X, y, W):
73+
def fit_storage(self, data):
74+
X, y = data.X, data.Y
75+
7476
if len(y.shape) > 1:
7577
raise ValueError('Softmax regression does not support '
7678
'multi-label classification')
@@ -81,7 +83,7 @@ def fit(self, X, y, W):
8183

8284
X = np.hstack((X, np.ones((X.shape[0], 1))))
8385

84-
self.num_classes = np.unique(y).size
86+
self.num_classes = len(data.domain.class_var.values)
8587
Y = np.eye(self.num_classes)[y.ravel().astype(int)]
8688

8789
theta = np.zeros(self.num_classes * X.shape[1])

Orange/classification/svm.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
class SVMLearner(SklLearner):
1212
__wraps__ = skl_svm.SVC
13+
__returns__ = SklModel
1314
preprocessors = svm_pps
1415

1516
def __init__(self, C=1.0, kernel='rbf', degree=3, gamma="auto",
@@ -32,19 +33,9 @@ def __init__(self, penalty='l2', loss='squared_hinge', dual=True,
3233
self.params = vars()
3334

3435

35-
class NuSVMClassifier(SklModel):
36-
37-
def predict(self, X):
38-
value = self.skl_model.predict(X)
39-
if self.skl_model.probability:
40-
prob = self.skl_model.predict_proba(X)
41-
return value, prob
42-
return value
43-
44-
4536
class NuSVMLearner(SklLearner):
4637
__wraps__ = skl_svm.NuSVC
47-
__returns__ = NuSVMClassifier
38+
__returns__ = SklModel
4839
preprocessors = svm_pps
4940

5041
def __init__(self, nu=0.5, kernel='rbf', degree=3, gamma="auto", coef0=0.0,
@@ -55,10 +46,11 @@ def __init__(self, nu=0.5, kernel='rbf', degree=3, gamma="auto", coef0=0.0,
5546

5647

5748
if __name__ == '__main__':
58-
import Orange
49+
from Orange.evaluation import CrossValidation, CA
50+
from Orange.data import Table
5951

60-
data = Orange.data.Table('iris')
52+
data_ = Table('iris')
6153
learners = [SVMLearner(), NuSVMLearner(), LinearSVMLearner()]
62-
res = Orange.evaluation.CrossValidation(data, learners)
63-
for l, ca in zip(learners, Orange.evaluation.CA(res)):
54+
res = CrossValidation()(data_, learners)
55+
for l, ca in zip(learners, CA()(res)):
6456
print("learner: {}\nCA: {}\n".format(l, ca))

Orange/tests/test_classification.py

Lines changed: 83 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222
MajorityLearner,
2323
RandomForestLearner, SimpleTreeLearner, SoftmaxRegressionLearner,
2424
SVMLearner, LinearSVMLearner, OneClassSVMLearner, TreeLearner, KNNLearner,
25-
SimpleRandomForestLearner, EllipticEnvelopeLearner,
26-
SGDClassificationLearner)
25+
SimpleRandomForestLearner, EllipticEnvelopeLearner)
2726
from Orange.classification.rules import _RuleLearner
2827
from Orange.data import (ContinuousVariable, DiscreteVariable,
2928
Domain, Table)
@@ -33,6 +32,24 @@
3332
from Orange.tests import test_filename
3433

3534

35+
def all_learners():
36+
classification_modules = pkgutil.walk_packages(
37+
path=Orange.classification.__path__,
38+
prefix="Orange.classification.",
39+
onerror=lambda x: None)
40+
for importer, modname, ispkg in classification_modules:
41+
try:
42+
module = pkgutil.importlib.import_module(modname)
43+
except ImportError:
44+
continue
45+
46+
for name, class_ in inspect.getmembers(module, inspect.isclass):
47+
if (issubclass(class_, Learner) and
48+
not name.startswith('_') and
49+
'base' not in class_.__module__):
50+
yield class_
51+
52+
3653
class MultiClassTest(unittest.TestCase):
3754
def test_unsupported(self):
3855
nrows = 20
@@ -194,21 +211,70 @@ def test_incompatible_domain(self):
194211

195212
def test_result_shape(self):
196213
"""
197-
This test function will be extended for all models in on of the
198-
following pull requests.
214+
Test if the results shapes are correct
199215
"""
200216
iris = Table('iris')
201-
learner = SGDClassificationLearner()
217+
for learner in all_learners():
218+
with self.subTest(learner.__name__):
219+
# model trained on only one value (but three in the domain)
220+
try:
221+
model = learner()(iris[0:50])
222+
except TypeError:
223+
# cannot be tested with default parameters
224+
continue
225+
226+
res = model(iris[0:50])
227+
self.assertTupleEqual((50,), res.shape)
228+
229+
# probabilities must still be for three classes
230+
res = model(iris[0:50], model.Probs)
231+
self.assertTupleEqual((50, 3), res.shape)
202232

203-
# model trained on only one value (but three in the domain)
204-
model = learner(iris)
233+
# model trained on all classes and predicting with one class
234+
try:
235+
model = learner()(iris[0:50])
236+
except TypeError:
237+
# cannot be tested with default parameters
238+
continue
239+
res = model(iris[0:50], model.Probs)
240+
self.assertTupleEqual((50, 3), res.shape)
205241

206-
res = model(iris[0:50])
207-
self.assertTupleEqual((50,), res.shape)
242+
def test_result_shape_numpy(self):
243+
"""
244+
Test whether results shapes are correct when testing on numpy data
245+
"""
246+
iris = Table('iris')
247+
for learner in all_learners():
248+
with self.subTest(learner.__name__):
249+
if learner.__name__ == "CN2SDLearner":
250+
# TODO: fix CN2SDLearner
251+
continue
252+
try:
253+
model = learner()(iris)
254+
except TypeError:
255+
# cannot be tested with default parameters
256+
continue
257+
transformed_iris = model.data_to_model_domain(iris)
208258

209-
# probabilities must still be for three classes
210-
res = model(iris[0:50], model.Probs)
211-
self.assertTupleEqual((50, 3), res.shape)
259+
res = model(transformed_iris.X[0:5])
260+
self.assertTupleEqual((5,), res.shape)
261+
262+
res = model(transformed_iris.X[0:1], model.Probs)
263+
self.assertTupleEqual((1, 3), res.shape)
264+
265+
def test_fit_one_class(self):
266+
"""
267+
Test whether the fitting with one class only pass - before it failed
268+
for some models.
269+
"""
270+
iris = Table('iris')
271+
for learner in all_learners():
272+
with self.subTest(learner.__name__):
273+
try:
274+
model = learner()(iris[:50])
275+
except TypeError:
276+
# cannot be tested with default parameters
277+
continue
212278

213279

214280
class ExpandProbabilitiesTest(unittest.TestCase):
@@ -309,7 +375,7 @@ def test_unknown(self):
309375

310376
def test_missing_class(self):
311377
table = Table(test_filename("datasets/adult_sample_missing"))
312-
for learner in LearnerAccessibility().all_learners():
378+
for learner in all_learners():
313379
try:
314380
learner = learner()
315381
if isinstance(learner, NuSVMLearner):
@@ -330,33 +396,15 @@ def setUp(self):
330396
# Convergence warnings are irrelevant for these tests
331397
warnings.filterwarnings("ignore", ".*", ConvergenceWarning)
332398

333-
334-
def all_learners(self):
335-
classification_modules = pkgutil.walk_packages(
336-
path=Orange.classification.__path__,
337-
prefix="Orange.classification.",
338-
onerror=lambda x: None)
339-
for importer, modname, ispkg in classification_modules:
340-
try:
341-
module = pkgutil.importlib.import_module(modname)
342-
except ImportError:
343-
continue
344-
345-
for name, class_ in inspect.getmembers(module, inspect.isclass):
346-
if (issubclass(class_, Learner) and
347-
not name.startswith('_') and
348-
'base' not in class_.__module__):
349-
yield class_
350-
351399
def test_all_learners_accessible_in_Orange_classification_namespace(self):
352-
for learner in self.all_learners():
400+
for learner in all_learners():
353401
if not hasattr(Orange.classification, learner.__name__):
354402
self.fail("%s is not visible in Orange.classification"
355403
" namespace" % learner.__name__)
356404

357405
def test_all_models_work_after_unpickling(self):
358406
datasets = [Table('iris'), Table('titanic')]
359-
for learner in list(self.all_learners()):
407+
for learner in list(all_learners()):
360408
try:
361409
learner = learner()
362410
except Exception:
@@ -381,7 +429,7 @@ def test_all_models_work_after_unpickling(self):
381429
% (learner.__class__.__name__, ds.name))
382430

383431
def test_adequacy_all_learners(self):
384-
for learner in self.all_learners():
432+
for learner in all_learners():
385433
try:
386434
learner = learner()
387435
table = Table("housing")
@@ -391,7 +439,7 @@ def test_adequacy_all_learners(self):
391439
continue
392440

393441
def test_adequacy_all_learners_multiclass(self):
394-
for learner in self.all_learners():
442+
for learner in all_learners():
395443
try:
396444
learner = learner()
397445
table = Table(test_filename("datasets/test8.tab"))

0 commit comments

Comments
 (0)