Skip to content

Commit 28326e0

Browse files
authored
Merge pull request #5848 from JakaKokosar/multi_target
[ENH] Enable multitarget problem types for OWTestAndScore and OWPredictions
2 parents 93cf05f + f820dea commit 28326e0

File tree

15 files changed

+256
-73
lines changed

15 files changed

+256
-73
lines changed

Orange/base.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from collections.abc import Iterable
44
import re
55
import warnings
6-
from typing import Callable, Dict
6+
from typing import Callable, Dict, Optional
77

88
import numpy as np
99
import scipy
1010

11-
from Orange.data import Table, Storage, Instance, Value
11+
from Orange.data import Table, Storage, Instance, Value, Domain
1212
from Orange.data.filter import HasClass
1313
from Orange.data.table import DomainTransformationError
1414
from Orange.data.util import one_hot
@@ -86,6 +86,9 @@ class Learner(ReprableWithPreprocessors):
8686
#: A sequence of data preprocessors to apply on data prior to
8787
#: fitting the model
8888
preprocessors = ()
89+
90+
# Note: Do not use this class attribute.
91+
# It remains here for compatibility reasons.
8992
learner_adequacy_err_msg = ''
9093

9194
def __init__(self, preprocessors=None):
@@ -95,6 +98,7 @@ def __init__(self, preprocessors=None):
9598
elif preprocessors:
9699
self.preprocessors = (preprocessors,)
97100

101+
# pylint: disable=R0201
98102
def fit(self, X, Y, W=None):
99103
raise RuntimeError(
100104
"Descendants of Learner must overload method fit or fit_storage")
@@ -106,8 +110,23 @@ def fit_storage(self, data):
106110
return self.fit(X, Y, W)
107111

108112
def __call__(self, data, progress_callback=None):
109-
if not self.check_learner_adequacy(data.domain):
110-
raise ValueError(self.learner_adequacy_err_msg)
113+
114+
for cls in type(self).mro():
115+
if 'incompatibility_reason' in cls.__dict__:
116+
incompatibility_reason = \
117+
self.incompatibility_reason(data.domain) # pylint: disable=assignment-from-none
118+
if incompatibility_reason is not None:
119+
raise ValueError(incompatibility_reason)
120+
break
121+
if 'check_learner_adequacy' in cls.__dict__:
122+
warnings.warn(
123+
"check_learner_adequacy is deprecated and will be removed "
124+
"in upcoming releases. Learners should instead implement "
125+
"the incompatibility_reason method.",
126+
OrangeDeprecationWarning)
127+
if not self.check_learner_adequacy(data.domain):
128+
raise ValueError(self.learner_adequacy_err_msg)
129+
break
111130

112131
origdomain = data.domain
113132

@@ -176,6 +195,11 @@ def active_preprocessors(self):
176195
def check_learner_adequacy(self, _):
177196
return True
178197

198+
# pylint: disable=no-self-use
199+
def incompatibility_reason(self, _: Domain) -> Optional[str]:
200+
"""Return None if a learner can fit domain or string explaining why it can not."""
201+
return None
202+
179203
@property
180204
def name(self):
181205
"""Return a short name derived from Learner type name"""

Orange/classification/base_classification.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55

66

77
class LearnerClassification(Learner):
8-
learner_adequacy_err_msg = "Categorical class variable expected."
98

10-
def check_learner_adequacy(self, domain):
11-
return domain.has_discrete_class
9+
def incompatibility_reason(self, domain):
10+
reason = None
11+
if len(domain.class_vars) > 1 and not self.supports_multiclass:
12+
reason = "Too many target variables."
13+
elif not domain.has_discrete_class:
14+
reason = "Categorical class variable expected."
15+
return reason
1216

1317

1418
class ModelClassification(Model):

Orange/evaluation/clustering.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ def get_fold(self, fold):
3232
class ClusteringScore(Score):
3333
considers_actual = False
3434

35+
@staticmethod
36+
def is_compatible(domain) -> bool:
37+
return True
38+
3539
# pylint: disable=arguments-differ
3640
def from_predicted(self, results, score_function):
3741
# Clustering scores from labels

Orange/evaluation/scoring.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import sklearn.metrics as skl_metrics
1717
from sklearn.metrics import confusion_matrix
1818

19-
from Orange.data import DiscreteVariable, ContinuousVariable
19+
from Orange.data import DiscreteVariable, ContinuousVariable, Domain
2020
from Orange.misc.wrapper_meta import WrapperMeta
2121

2222
__all__ = ["CA", "Precision", "Recall", "F1", "PrecisionRecallFSupport", "AUC",
@@ -112,14 +112,26 @@ def from_predicted(results, score_function, **kwargs):
112112
for predicted in results.predicted),
113113
dtype=np.float64, count=len(results.predicted))
114114

115+
@staticmethod
116+
def is_compatible(domain: Domain) -> bool:
117+
raise NotImplementedError
118+
115119

116120
class ClassificationScore(Score, abstract=True):
117121
class_types = (DiscreteVariable, )
118122

123+
@staticmethod
124+
def is_compatible(domain: Domain) -> bool:
125+
return domain.has_discrete_class
126+
119127

120128
class RegressionScore(Score, abstract=True):
121129
class_types = (ContinuousVariable, )
122130

131+
@staticmethod
132+
def is_compatible(domain: Domain) -> bool:
133+
return domain.has_continuous_class
134+
123135

124136
# pylint: disable=invalid-name
125137
class CA(ClassificationScore):

Orange/evaluation/testing.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import sklearn.model_selection as skl
1111

12-
from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable
12+
from Orange.data import Domain, ContinuousVariable, DiscreteVariable
1313
from Orange.data.util import get_unique_names
1414

1515
__all__ = ["Results", "CrossValidation", "LeaveOneOut", "TestOnTrainingData",
@@ -37,9 +37,10 @@ def _mp_worker(fold_i, train_data, test_data, learner_i, learner,
3737
train_time = time() - t0
3838
t0 = time()
3939
# testing
40-
if train_data.domain.has_discrete_class:
40+
class_var = train_data.domain.class_var
41+
if class_var and class_var.is_discrete:
4142
predicted, probs = model(test_data, model.ValueProbs)
42-
elif train_data.domain.has_continuous_class:
43+
else:
4344
predicted = model(test_data, model.Value)
4445
test_time = time() - t0
4546
# Different models can fail at any time raising any exception
@@ -269,7 +270,7 @@ def get_augmented_data(self, model_names,
269270
new_meta_vals = np.empty((len(data), 0))
270271
names = [var.name for var in chain(domain.attributes,
271272
domain.metas,
272-
[class_var])]
273+
domain.class_vars)]
273274

274275
if classification:
275276
# predictions
@@ -501,8 +502,7 @@ def prepare_arrays(cls, data, indices):
501502
ptr += len(test)
502503

503504
row_indices = np.concatenate(row_indices, axis=0)
504-
actual = data[row_indices].Y.ravel()
505-
return folds, row_indices, actual
505+
return folds, row_indices, data[row_indices].Y
506506

507507
@staticmethod
508508
def get_indices(data):
@@ -751,7 +751,7 @@ def __call__(self, data, test_data, learners, preprocessor=None,
751751
nrows=len(test_data), learners=learners,
752752
row_indices=np.arange(len(test_data)),
753753
folds=(Ellipsis, ),
754-
actual=test_data.Y.ravel(),
754+
actual=test_data.Y,
755755
score_by_folds=self.score_by_folds,
756756
train_time=np.zeros((len(learners),)),
757757
test_time=np.zeros((len(learners),)))

Orange/preprocess/impute.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,8 @@ def __call__(self, data, variable):
224224
variable = data.domain[variable]
225225
domain = domain_with_class_var(data.domain, variable)
226226

227-
if self.learner.check_learner_adequacy(domain):
227+
incompatibility_reason = self.learner.incompatibility_reason(domain)
228+
if incompatibility_reason is None:
228229
data = data.transform(domain)
229230
model = self.learner(data)
230231
assert model.domain.class_var == variable
@@ -239,7 +240,7 @@ def copy(self):
239240

240241
def supports_variable(self, variable):
241242
domain = Orange.data.Domain([], class_vars=variable)
242-
return self.learner.check_learner_adequacy(domain)
243+
return self.learner.incompatibility_reason(domain) is None
243244

244245

245246
def domain_with_class_var(domain, class_var):

Orange/regression/base_regression.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55

66

77
class LearnerRegression(Learner):
8-
learner_adequacy_err_msg = "Numeric class variable expected."
98

10-
def check_learner_adequacy(self, domain):
11-
return domain.has_continuous_class
9+
def incompatibility_reason(self, domain):
10+
reason = None
11+
if len(domain.class_vars) > 1 and not self.supports_multiclass:
12+
reason = "Too many target variables."
13+
elif not domain.has_continuous_class:
14+
reason = "Numeric class variable expected."
15+
return reason
1216

1317

1418
class ModelRegression(Model):

Orange/tests/dummy_learners.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ def __init__(self, value, prob):
3333
class DummyMulticlassLearner(SklLearner):
3434
supports_multiclass = True
3535

36-
def check_learner_adequacy(self, domain):
37-
return all(c.is_discrete for c in domain.class_vars)
36+
def incompatibility_reason(self, domain):
37+
reason = 'Not all class variables are discrete'
38+
return None if all(c.is_discrete for c in domain.class_vars) else reason
3839

3940
def fit(self, X, Y, W):
4041
rows, class_vars = Y.shape

Orange/tests/test_base.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,59 @@
22
# pylint: disable=missing-docstring
33
import pickle
44
import unittest
5+
from distutils.version import LooseVersion
56

7+
import Orange
68
from Orange.base import SklLearner, Learner, Model
79
from Orange.data import Domain, Table
810
from Orange.preprocess import Discretize, Randomize, Continuize
911
from Orange.regression import LinearRegressionLearner
12+
from Orange.util import OrangeDeprecationWarning
13+
14+
15+
class DummyLearnerDeprecated(Learner):
16+
17+
def fit(self, *_, **__):
18+
return unittest.mock.Mock()
19+
20+
def check_learner_adequacy(self, _):
21+
return True
1022

1123

1224
class DummyLearner(Learner):
25+
1326
def fit(self, *_, **__):
1427
return unittest.mock.Mock()
1528

1629

1730
class DummySklLearner(SklLearner):
31+
1832
def fit(self, *_, **__):
1933
return unittest.mock.Mock()
2034

2135

2236
class DummyLearnerPP(Learner):
37+
2338
preprocessors = (Randomize(),)
2439

2540

2641
class TestLearner(unittest.TestCase):
42+
43+
def test_if_deprecation_warning_is_raised(self):
44+
with self.assertWarns(OrangeDeprecationWarning):
45+
DummyLearnerDeprecated()(Table('iris'))
46+
47+
def test_check_learner_adequacy_deprecated(self):
48+
"""This test is to be included in the 3.32 release and will fail in
49+
version 3.34. This serves as a reminder to remove the deprecated method
50+
and this test."""
51+
if LooseVersion(Orange.__version__) >= LooseVersion("3.34"):
52+
self.fail(
53+
"`Orange.base.Learner.check_learner_adequacy` was deprecated in "
54+
"version 3.32, and there have been two minor versions in "
55+
"between. Please remove the deprecated method."
56+
)
57+
2758
def test_uses_default_preprocessors_unless_custom_pps_specified(self):
2859
"""Learners should use their default preprocessors unless custom
2960
preprocessors were passed in to the constructor"""

0 commit comments

Comments
 (0)