Skip to content

Commit a660e7b

Browse files
committed
Fix: classification models output correct shapes
1 parent e0f241b commit a660e7b

File tree

5 files changed

+175
-41
lines changed

5 files changed

+175
-41
lines changed

Orange/base.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -320,12 +320,55 @@ def backmap_probs(self, probs, n_values, backmappers):
320320
new_probs = new_probs / tots[:, None]
321321
return new_probs
322322

323+
def data_to_model_domain(self, data: Table) -> Table:
324+
"""
325+
Transforms data to the model domain if possible.
326+
327+
Parameters
328+
----------
329+
data
330+
Data to be transformed to the model domain
331+
332+
Returns
333+
-------
334+
Transformed data table
335+
336+
Raises
337+
------
338+
DomainTransformationError
339+
Error indicates that transformation is not possible since domains
340+
are not compatible
341+
"""
342+
if data.domain == self.domain:
343+
return data
344+
345+
if self.original_domain.attributes != data.domain.attributes \
346+
and data.X.size \
347+
and not all_nan(data.X):
348+
new_data = data.transform(self.original_domain)
349+
if all_nan(new_data.X):
350+
raise DomainTransformationError(
351+
"domain transformation produced no defined values")
352+
return new_data.transform(self.domain)
353+
354+
return data.transform(self.domain)
355+
323356
def __call__(self, data, ret=Value):
324357
multitarget = len(self.domain.class_vars) > 1
325358

326359
def one_hot_probs(value):
327360
if not multitarget:
328-
return one_hot(value)
361+
probs = one_hot(value)
362+
# sometimes the maximum class value(s) are not predicted, and
363+
# then probs.shape[1] does not match the length of the values
364+
# in the domain - here we add missing dimension if required
365+
num_val = (
366+
len(self.original_domain.class_var.values)
367+
if self.original_domain is not None else probs.shape[0]
368+
)
369+
return np.hstack((
370+
probs, np.zeros((probs.shape[0], num_val - probs.shape[1]))
371+
))
329372

330373
max_card = max(len(c.values) for c in self.domain.class_vars)
331374
probs = np.zeros(value.shape + (max_card,), float)
@@ -336,21 +379,6 @@ def one_hot_probs(value):
336379
def fix_dim(x):
337380
return x[0] if one_d else x
338381

339-
def data_to_model_domain():
340-
if data.domain == self.domain:
341-
return data
342-
343-
if self.original_domain.attributes != data.domain.attributes \
344-
and data.X.size \
345-
and not all_nan(data.X):
346-
new_data = data.transform(self.original_domain)
347-
if all_nan(new_data.X):
348-
raise DomainTransformationError(
349-
"domain transformation produced no defined values")
350-
return new_data.transform(self.domain)
351-
352-
return data.transform(self.domain)
353-
354382
if not 0 <= ret <= 2:
355383
raise ValueError("invalid value of argument 'ret'")
356384
if ret > 0 and any(v.is_continuous for v in self.domain.class_vars):
@@ -368,14 +396,18 @@ def data_to_model_domain():
368396
else:
369397
one_d = False
370398

399+
# if sparse convert to csr_matrix
400+
if scipy.sparse.issparse(data):
401+
data = data.tocsr()
402+
371403
# Call the predictor
372404
backmappers = None
373405
n_values = []
374406
if isinstance(data, (np.ndarray, scipy.sparse.csr.csr_matrix)):
375407
prediction = self.predict(data)
376408
elif isinstance(data, Table):
377409
backmappers, n_values = self.get_backmappers(data)
378-
data = data_to_model_domain()
410+
data = self.data_to_model_domain(data)
379411
prediction = self.predict_storage(data)
380412
elif isinstance(data, (list, tuple)):
381413
data = Table.from_list(self.original_domain, data)

Orange/classification/softmax_regression.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ def cost_grad(self, Theta_flat, X, Y):
7070

7171
return cost, grad.ravel()
7272

73-
def fit(self, X, y, W):
73+
def fit_storage(self, data):
74+
X, y = data.X, data.Y
75+
7476
if len(y.shape) > 1:
7577
raise ValueError('Softmax regression does not support '
7678
'multi-label classification')
@@ -81,7 +83,7 @@ def fit(self, X, y, W):
8183

8284
X = np.hstack((X, np.ones((X.shape[0], 1))))
8385

84-
self.num_classes = np.unique(y).size
86+
self.num_classes = len(data.domain.class_var.values)
8587
Y = np.eye(self.num_classes)[y.ravel().astype(int)]
8688

8789
theta = np.zeros(self.num_classes * X.shape[1])

Orange/classification/svm.py

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
1+
import warnings
2+
3+
import numpy as np
14
import sklearn.svm as skl_svm
25

3-
from Orange.classification import SklLearner, SklModel
6+
from Orange.classification import SklLearner, SklModel, MajorityLearner
7+
from Orange.evaluation import CrossValidation, CA
48
from Orange.preprocess import AdaptiveNormalize
59

610
__all__ = ["SVMLearner", "LinearSVMLearner", "NuSVMLearner"]
711

812
svm_pps = SklLearner.preprocessors + [AdaptiveNormalize()]
913

1014

11-
class SVMClassifier(SklModel):
12-
pass
13-
14-
1515
class SVMLearner(SklLearner):
1616
__wraps__ = skl_svm.SVC
17-
__returns__ = SVMClassifier
17+
__returns__ = SklModel
1818
preprocessors = svm_pps
1919

2020
def __init__(self, C=1.0, kernel='rbf', degree=3, gamma="auto",
@@ -24,6 +24,14 @@ def __init__(self, C=1.0, kernel='rbf', degree=3, gamma="auto",
2424
super().__init__(preprocessors=preprocessors)
2525
self.params = vars()
2626

27+
def __call__(self, data, progress_callback=None):
28+
if len(np.unique(data.Y)) > 1:
29+
return super().__call__(data, progress_callback)
30+
else:
31+
warnings.warn("Single class in data, returning Constant Model.")
32+
maj = MajorityLearner()
33+
return maj(data)
34+
2735

2836
class LinearSVMLearner(SklLearner):
2937
__wraps__ = skl_svm.LinearSVC
@@ -36,20 +44,18 @@ def __init__(self, penalty='l2', loss='squared_hinge', dual=True,
3644
super().__init__(preprocessors=preprocessors)
3745
self.params = vars()
3846

39-
40-
class NuSVMClassifier(SklModel):
41-
42-
def predict(self, X):
43-
value = self.skl_model.predict(X)
44-
if self.skl_model.probability:
45-
prob = self.skl_model.predict_proba(X)
46-
return value, prob
47-
return value
47+
def __call__(self, data, progress_callback=None):
48+
if len(np.unique(data.Y)) > 1:
49+
return super().__call__(data, progress_callback)
50+
else:
51+
warnings.warn("Single class in data, returning Constant Model.")
52+
maj = MajorityLearner()
53+
return maj(data)
4854

4955

5056
class NuSVMLearner(SklLearner):
5157
__wraps__ = skl_svm.NuSVC
52-
__returns__ = NuSVMClassifier
58+
__returns__ = SklModel
5359
preprocessors = svm_pps
5460

5561
def __init__(self, nu=0.5, kernel='rbf', degree=3, gamma="auto", coef0=0.0,
@@ -58,12 +64,22 @@ def __init__(self, nu=0.5, kernel='rbf', degree=3, gamma="auto", coef0=0.0,
5864
super().__init__(preprocessors=preprocessors)
5965
self.params = vars()
6066

67+
def __call__(self, data, progress_callback=None):
68+
if len(np.unique(data.Y)) > 1:
69+
return super().__call__(data, progress_callback)
70+
else:
71+
warnings.warn("Single class in data, returning Constant Model.")
72+
maj = MajorityLearner()
73+
return maj(data)
74+
6175

6276
if __name__ == '__main__':
6377
import Orange
6478

65-
data = Orange.data.Table('iris')
79+
data_ = Orange.data.Table('iris')
6680
learners = [SVMLearner(), NuSVMLearner(), LinearSVMLearner()]
67-
res = Orange.evaluation.CrossValidation(data, learners)
68-
for l, ca in zip(learners, Orange.evaluation.CA(res)):
81+
cv = CrossValidation()
82+
res = cv(data_, learners)
83+
ca = CA()
84+
for l, ca in zip(learners, ca(res)):
6985
print("learner: {}\nCA: {}\n".format(l, ca))

Orange/classification/tests/test_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,17 @@ def test_no_common_values(self):
186186
self.assertTrue(np.all(val <= 2))
187187
np.testing.assert_array_equal(prob, 1 / 3)
188188

189+
def test_sparse_matrix(self):
190+
iris_sparse = self.iris.to_sparse()
191+
for lrn in [TreeLearner, LogisticRegressionLearner]: # skl and non-skl
192+
model = lrn()(iris_sparse)
193+
pred = model(iris_sparse.X.tocsc())
194+
self.assertTupleEqual((len(self.iris),), pred.shape)
195+
pred = model(iris_sparse.X.tolil())
196+
self.assertTupleEqual((len(self.iris),), pred.shape)
197+
pred = model(iris_sparse.X.tocoo())
198+
self.assertTupleEqual((len(self.iris),), pred.shape)
199+
189200

190201
if __name__ == '__main__':
191202
unittest.main()

Orange/tests/test_classification.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@
1818
import Orange.classification
1919
from Orange.classification import (
2020
Learner, Model,
21-
NaiveBayesLearner, LogisticRegressionLearner, NuSVMLearner, MajorityLearner,
21+
NaiveBayesLearner, LogisticRegressionLearner, NuSVMLearner,
22+
MajorityLearner,
2223
RandomForestLearner, SimpleTreeLearner, SoftmaxRegressionLearner,
2324
SVMLearner, LinearSVMLearner, OneClassSVMLearner, TreeLearner, KNNLearner,
24-
SimpleRandomForestLearner, EllipticEnvelopeLearner)
25+
SimpleRandomForestLearner, EllipticEnvelopeLearner, SklTreeLearner,
26+
NNClassificationLearner, CN2Learner, CN2UnorderedLearner,
27+
CN2SDUnorderedLearner)
2528
from Orange.classification.rules import _RuleLearner
2629
from Orange.data import (ContinuousVariable, DiscreteVariable,
27-
Domain, Table, Variable)
30+
Domain, Table)
2831
from Orange.data.table import DomainTransformationError
2932
from Orange.evaluation import CrossValidation
3033
from Orange.tests.dummy_learners import DummyLearner, DummyMulticlassLearner
@@ -190,6 +193,76 @@ def test_incompatible_domain(self):
190193
with self.assertRaises(DomainTransformationError):
191194
clf(titanic)
192195

196+
LEARNERS_CLASSIFICATION = [
197+
LogisticRegressionLearner(),
198+
RandomForestLearner(),
199+
SimpleRandomForestLearner(),
200+
SoftmaxRegressionLearner(),
201+
KNNLearner(),
202+
NaiveBayesLearner(),
203+
SVMLearner(probability=True),
204+
SVMLearner(),
205+
LinearSVMLearner(),
206+
NuSVMLearner(probability=True),
207+
NuSVMLearner(),
208+
TreeLearner(),
209+
SklTreeLearner(),
210+
SimpleTreeLearner(),
211+
MajorityLearner(),
212+
NNClassificationLearner(),
213+
CN2Learner(),
214+
CN2UnorderedLearner(),
215+
CN2SDUnorderedLearner()
216+
]
217+
218+
def test_result_shape(self):
219+
"""
220+
Test if the results shapes are correct
221+
"""
222+
iris = Table('iris')
223+
for learner in self.LEARNERS_CLASSIFICATION:
224+
with self.subTest(learner.name):
225+
# model trained on only one value (but three in the domain)
226+
model = learner(iris[0:50])
227+
228+
res = model(iris[0:50])
229+
self.assertTupleEqual((50, ), res.shape)
230+
231+
# probabilities must still be for three classes
232+
res = model(iris[0:50], model.Probs)
233+
self.assertTupleEqual((50, 3), res.shape)
234+
235+
# model trained on all classes and predicting with one class
236+
model = learner(iris)
237+
res = model(iris[0:50], model.Probs)
238+
self.assertTupleEqual((50, 3), res.shape)
239+
240+
def test_result_shape_numpy(self):
241+
"""
242+
Test whether results shapes are correct when testing on numpy data
243+
"""
244+
iris = Table('iris')
245+
for learner in self.LEARNERS_CLASSIFICATION:
246+
with self.subTest(learner.name):
247+
model = learner(iris)
248+
transformed_iris = model.data_to_model_domain(iris)
249+
250+
res = model(transformed_iris.X[0:5])
251+
self.assertTupleEqual((5, ), res.shape)
252+
253+
res = model(transformed_iris.X[0:1], model.Probs)
254+
self.assertTupleEqual((1, 3), res.shape)
255+
256+
def test_fit_one_class(self):
257+
"""
258+
Test whether fitting with only one class passes - before it failed
259+
for some models.
260+
"""
261+
iris = Table('iris')
262+
for learner in self.LEARNERS_CLASSIFICATION:
263+
with self.subTest(learner.name):
264+
learner(iris[0:50])
265+
193266

194267
class ExpandProbabilitiesTest(unittest.TestCase):
195268
def prepareTable(self, rows, attr, vars, class_var_domain):

0 commit comments

Comments
 (0)