Skip to content

Commit 1d00add

Browse files
committed
Fix: classification models output correct shapes
1 parent e0f241b commit 1d00add

File tree

4 files changed

+52
-34
lines changed

4 files changed

+52
-34
lines changed

Orange/base.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,39 @@ def backmap_probs(self, probs, n_values, backmappers):
320320
new_probs = new_probs / tots[:, None]
321321
return new_probs
322322

323+
def data_to_model_domain(self, data: Table) -> Table:
324+
"""
325+
Transforms data to the model domain if possible.
326+
327+
Parameters
328+
----------
329+
data
330+
Data to be transformed to the model domain
331+
332+
Returns
333+
-------
334+
Transformed data table
335+
336+
Raises
337+
------
338+
DomainTransformationError
339+
Error indicates that transformation is not possible since domains
340+
are not compatible
341+
"""
342+
if data.domain == self.domain:
343+
return data
344+
345+
if self.original_domain.attributes != data.domain.attributes \
346+
and data.X.size \
347+
and not all_nan(data.X):
348+
new_data = data.transform(self.original_domain)
349+
if all_nan(new_data.X):
350+
raise DomainTransformationError(
351+
"domain transformation produced no defined values")
352+
return new_data.transform(self.domain)
353+
354+
return data.transform(self.domain)
355+
323356
def __call__(self, data, ret=Value):
324357
multitarget = len(self.domain.class_vars) > 1
325358

@@ -336,21 +369,6 @@ def one_hot_probs(value):
336369
def fix_dim(x):
337370
return x[0] if one_d else x
338371

339-
def data_to_model_domain():
340-
if data.domain == self.domain:
341-
return data
342-
343-
if self.original_domain.attributes != data.domain.attributes \
344-
and data.X.size \
345-
and not all_nan(data.X):
346-
new_data = data.transform(self.original_domain)
347-
if all_nan(new_data.X):
348-
raise DomainTransformationError(
349-
"domain transformation produced no defined values")
350-
return new_data.transform(self.domain)
351-
352-
return data.transform(self.domain)
353-
354372
if not 0 <= ret <= 2:
355373
raise ValueError("invalid value of argument 'ret'")
356374
if ret > 0 and any(v.is_continuous for v in self.domain.class_vars):
@@ -368,14 +386,18 @@ def data_to_model_domain():
368386
else:
369387
one_d = False
370388

389+
# if sparse convert to csr_matrix
390+
if scipy.sparse.issparse(data):
391+
data = data.tocsr()
392+
371393
# Call the predictor
372394
backmappers = None
373395
n_values = []
374396
if isinstance(data, (np.ndarray, scipy.sparse.csr.csr_matrix)):
375397
prediction = self.predict(data)
376398
elif isinstance(data, Table):
377399
backmappers, n_values = self.get_backmappers(data)
378-
data = data_to_model_domain()
400+
data = self.data_to_model_domain(data)
379401
prediction = self.predict_storage(data)
380402
elif isinstance(data, (list, tuple)):
381403
data = Table.from_list(self.original_domain, data)

Orange/classification/svm.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,8 @@
88
svm_pps = SklLearner.preprocessors + [AdaptiveNormalize()]
99

1010

11-
class SVMClassifier(SklModel):
12-
pass
13-
14-
1511
class SVMLearner(SklLearner):
1612
__wraps__ = skl_svm.SVC
17-
__returns__ = SVMClassifier
1813
preprocessors = svm_pps
1914

2015
def __init__(self, C=1.0, kernel='rbf', degree=3, gamma="auto",
@@ -37,19 +32,8 @@ def __init__(self, penalty='l2', loss='squared_hinge', dual=True,
3732
self.params = vars()
3833

3934

40-
class NuSVMClassifier(SklModel):
41-
42-
def predict(self, X):
43-
value = self.skl_model.predict(X)
44-
if self.skl_model.probability:
45-
prob = self.skl_model.predict_proba(X)
46-
return value, prob
47-
return value
48-
49-
5035
class NuSVMLearner(SklLearner):
5136
__wraps__ = skl_svm.NuSVC
52-
__returns__ = NuSVMClassifier
5337
preprocessors = svm_pps
5438

5539
def __init__(self, nu=0.5, kernel='rbf', degree=3, gamma="auto", coef0=0.0,

Orange/classification/tests/test_base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,17 @@ def test_no_common_values(self):
186186
self.assertTrue(np.all(val <= 2))
187187
np.testing.assert_array_equal(prob, 1 / 3)
188188

189+
def test_sparse_matrix(self):
190+
iris_sparse = self.iris.to_sparse()
191+
for lrn in [LogisticRegressionLearner, TreeLearner]: # skl and non-skl
192+
model = lrn()(iris_sparse)
193+
pred = model(iris_sparse.X.tocsc())
194+
self.assertTupleEqual((len(self.iris),), pred.shape)
195+
pred = model(iris_sparse.X.tolil())
196+
self.assertTupleEqual((len(self.iris),), pred.shape)
197+
pred = model(iris_sparse.X.tocoo())
198+
self.assertTupleEqual((len(self.iris),), pred.shape)
199+
189200

190201
if __name__ == '__main__':
191202
unittest.main()

Orange/tests/test_classification.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818
import Orange.classification
1919
from Orange.classification import (
2020
Learner, Model,
21-
NaiveBayesLearner, LogisticRegressionLearner, NuSVMLearner, MajorityLearner,
21+
NaiveBayesLearner, LogisticRegressionLearner, NuSVMLearner,
22+
MajorityLearner,
2223
RandomForestLearner, SimpleTreeLearner, SoftmaxRegressionLearner,
2324
SVMLearner, LinearSVMLearner, OneClassSVMLearner, TreeLearner, KNNLearner,
2425
SimpleRandomForestLearner, EllipticEnvelopeLearner)
2526
from Orange.classification.rules import _RuleLearner
2627
from Orange.data import (ContinuousVariable, DiscreteVariable,
27-
Domain, Table, Variable)
28+
Domain, Table)
2829
from Orange.data.table import DomainTransformationError
2930
from Orange.evaluation import CrossValidation
3031
from Orange.tests.dummy_learners import DummyLearner, DummyMulticlassLearner

0 commit comments

Comments
 (0)