Skip to content

Commit e1e7419

Browse files
committed
learner adequacy check refactor
1 parent 19c1be7 commit e1e7419

File tree

7 files changed

+100
-26
lines changed

7 files changed

+100
-26
lines changed

Orange/base.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from collections.abc import Iterable
44
import re
55
import warnings
6-
from typing import Callable, Dict
6+
from typing import Callable, Dict, Optional
77

88
import numpy as np
99
import scipy
1010

11-
from Orange.data import Table, Storage, Instance, Value
11+
from Orange.data import Table, Storage, Instance, Value, Domain
1212
from Orange.data.filter import HasClass
1313
from Orange.data.table import DomainTransformationError
1414
from Orange.data.util import one_hot
@@ -86,6 +86,9 @@ class Learner(ReprableWithPreprocessors):
8686
#: A sequence of data preprocessors to apply on data prior to
8787
#: fitting the model
8888
preprocessors = ()
89+
90+
# Note: Do not use this class attribute; implement incompatibility_reason instead.
91+
# It remains here only for compatibility with the deprecated check_learner_adequacy.
8992
learner_adequacy_err_msg = ''
9093

9194
def __init__(self, preprocessors=None):
@@ -95,6 +98,7 @@ def __init__(self, preprocessors=None):
9598
elif preprocessors:
9699
self.preprocessors = (preprocessors,)
97100

101+
# pylint: disable=R0201
98102
def fit(self, X, Y, W=None):
99103
raise RuntimeError(
100104
"Descendants of Learner must overload method fit or fit_storage")
@@ -106,8 +110,23 @@ def fit_storage(self, data):
106110
return self.fit(X, Y, W)
107111

108112
def __call__(self, data, progress_callback=None):
109-
if not self.check_learner_adequacy(data.domain):
110-
raise ValueError(self.learner_adequacy_err_msg)
113+
114+
for cls in type(self).mro():
115+
if 'incompatibility_reason' in cls.__dict__:
116+
incompatibility_reason = \
117+
self.incompatibility_reason(data.domain) # pylint: disable=assignment-from-none
118+
if incompatibility_reason is not None:
119+
raise ValueError(incompatibility_reason)
120+
break
121+
if 'check_learner_adequacy' in cls.__dict__:
122+
warnings.warn(
123+
"check_learner_adequacy is deprecated and will be removed "
124+
"in upcoming releases. Learners should instead implement "
125+
"the incompatibility_reason method.",
126+
OrangeDeprecationWarning)
127+
if not self.check_learner_adequacy(data.domain):
128+
raise ValueError(self.learner_adequacy_err_msg)
129+
break
111130

112131
origdomain = data.domain
113132

@@ -176,6 +195,11 @@ def active_preprocessors(self):
176195
def check_learner_adequacy(self, _):
177196
return True
178197

198+
# pylint: disable=no-self-use
199+
def incompatibility_reason(self, _: Domain) -> Optional[str]:
200+
"""Return None if the learner can fit the domain, or a string explaining why it cannot."""
201+
return None
202+
179203
@property
180204
def name(self):
181205
"""Return a short name derived from Learner type name"""

Orange/classification/base_classification.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@
66

77
class LearnerClassification(Learner):
88

9-
def check_learner_adequacy(self, domain):
10-
is_adequate = True
11-
if len(domain.class_vars) > 1:
12-
is_adequate = False
13-
self.learner_adequacy_err_msg = "Too many target variables."
9+
def incompatibility_reason(self, domain):
10+
reason = None
11+
if len(domain.class_vars) > 1 and not self.supports_multiclass:
12+
reason = "Too many target variables."
1413
elif not domain.has_discrete_class:
15-
is_adequate = False
16-
self.learner_adequacy_err_msg = "Categorical class variable expected."
17-
return is_adequate
14+
reason = "Categorical class variable expected."
15+
return reason
1816

1917

2018
class ModelClassification(Model):

Orange/preprocess/impute.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,8 @@ def __call__(self, data, variable):
224224
variable = data.domain[variable]
225225
domain = domain_with_class_var(data.domain, variable)
226226

227-
if self.learner.check_learner_adequacy(domain):
227+
incompatibility_reason = self.learner.incompatibility_reason(domain)
228+
if incompatibility_reason is None:
228229
data = data.transform(domain)
229230
model = self.learner(data)
230231
assert model.domain.class_var == variable
@@ -239,7 +240,7 @@ def copy(self):
239240

240241
def supports_variable(self, variable):
241242
domain = Orange.data.Domain([], class_vars=variable)
242-
return self.learner.check_learner_adequacy(domain)
243+
return self.learner.incompatibility_reason(domain) is None
243244

244245

245246
def domain_with_class_var(domain, class_var):

Orange/regression/base_regression.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@
66

77
class LearnerRegression(Learner):
88

9-
def check_learner_adequacy(self, domain):
10-
is_adequate = True
11-
if len(domain.class_vars) > 1:
12-
is_adequate = False
13-
self.learner_adequacy_err_msg = "Too many target variables."
9+
def incompatibility_reason(self, domain):
10+
reason = None
11+
if len(domain.class_vars) > 1 and not self.supports_multiclass:
12+
reason = "Too many target variables."
1413
elif not domain.has_continuous_class:
15-
is_adequate = False
16-
self.learner_adequacy_err_msg = "Numeric class variable expected."
17-
return is_adequate
14+
reason = "Numeric class variable expected."
15+
return reason
1816

1917

2018
class ModelRegression(Model):

Orange/tests/dummy_learners.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ def __init__(self, value, prob):
3333
class DummyMulticlassLearner(SklLearner):
3434
supports_multiclass = True
3535

36-
def check_learner_adequacy(self, domain):
37-
return all(c.is_discrete for c in domain.class_vars)
36+
def incompatibility_reason(self, domain):
37+
reason = 'Not all class variables are discrete'
38+
return None if all(c.is_discrete for c in domain.class_vars) else reason
3839

3940
def fit(self, X, Y, W):
4041
rows, class_vars = Y.shape

Orange/tests/test_base.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,59 @@
22
# pylint: disable=missing-docstring
33
import pickle
44
import unittest
5+
from distutils.version import LooseVersion
56

7+
import Orange
68
from Orange.base import SklLearner, Learner, Model
79
from Orange.data import Domain, Table
810
from Orange.preprocess import Discretize, Randomize, Continuize
911
from Orange.regression import LinearRegressionLearner
12+
from Orange.util import OrangeDeprecationWarning
13+
14+
15+
class DummyLearnerDeprecated(Learner):
16+
17+
def fit(self, *_, **__):
18+
return unittest.mock.Mock()
19+
20+
def check_learner_adequacy(self, _):
21+
return True
1022

1123

1224
class DummyLearner(Learner):
25+
1326
def fit(self, *_, **__):
1427
return unittest.mock.Mock()
1528

1629

1730
class DummySklLearner(SklLearner):
31+
1832
def fit(self, *_, **__):
1933
return unittest.mock.Mock()
2034

2135

2236
class DummyLearnerPP(Learner):
37+
2338
preprocessors = (Randomize(),)
2439

2540

2641
class TestLearner(unittest.TestCase):
42+
43+
def test_if_deprecation_warning_is_raised(self):
44+
with self.assertWarns(OrangeDeprecationWarning):
45+
DummyLearnerDeprecated()(Table('iris'))
46+
47+
def test_check_learner_adequacy_deprecated(self):
48+
"""This test is to be included in the 3.32 release and will fail in
49+
version 3.34. This serves as a reminder to remove the deprecated method
50+
and this test."""
51+
if LooseVersion(Orange.__version__) >= LooseVersion("3.34"):
52+
self.fail(
53+
"`Orange.base.Learner.check_learner_adequacy` was deprecated in "
54+
"version 3.32, and there have been two minor versions in "
55+
"between. Please remove the deprecated method."
56+
)
57+
2758
def test_uses_default_preprocessors_unless_custom_pps_specified(self):
2859
"""Learners should use their default preprocessors unless custom
2960
preprocessors were passed in to the constructor"""

Orange/widgets/utils/owlearnerwidget.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from copy import deepcopy
2+
import warnings
23

34
from AnyQt.QtCore import QTimer, Qt
45

@@ -12,6 +13,7 @@
1213
from Orange.widgets.utils.signals import Output, Input
1314
from Orange.widgets.utils.sql import check_sql_input
1415
from Orange.widgets.widget import OWWidget, WidgetMetaClass, Msg
16+
from Orange.util import OrangeDeprecationWarning
1517

1618

1719
class OWBaseLearnerMeta(WidgetMetaClass):
@@ -246,8 +248,26 @@ def check_data(self):
246248
self.Error.sparse_not_supported.clear()
247249
if self.data is not None and self.learner is not None:
248250
self.Error.data_error.clear()
249-
if not self.learner.check_learner_adequacy(self.data.domain):
250-
self.Error.data_error(self.learner.learner_adequacy_err_msg)
251+
252+
incompatibility_reason = None
253+
for cls in type(self.learner).mro():
254+
if 'incompatibility_reason' in cls.__dict__:
255+
# pylint: disable=assignment-from-none
256+
incompatibility_reason = \
257+
self.learner.incompatibility_reason(self.data.domain)
258+
break
259+
if 'check_learner_adequacy' in cls.__dict__:
260+
warnings.warn(
261+
"check_learner_adequacy is deprecated and will be removed "
262+
"in upcoming releases. Learners should instead implement "
263+
"the incompatibility_reason method.",
264+
OrangeDeprecationWarning)
265+
if not self.learner.check_learner_adequacy(self.data.domain):
266+
incompatibility_reason = self.learner.learner_adequacy_err_msg
267+
break
268+
269+
if incompatibility_reason is not None:
270+
self.Error.data_error(incompatibility_reason)
251271
elif not len(self.data):
252272
self.Error.data_error("Dataset is empty.")
253273
elif len(ut.unique(self.data.Y)) < 2:
@@ -258,6 +278,7 @@ def check_data(self):
258278
self.Error.sparse_not_supported()
259279
else:
260280
self.valid_data = True
281+
261282
return self.valid_data
262283

263284
def settings_changed(self, *args, **kwargs):

0 commit comments

Comments
 (0)