Skip to content

Commit fedb67e

Browse files
committed
learner adequacy check refactor
1 parent bd858ac commit fedb67e

File tree

5 files changed

+32
-22
lines changed

5 files changed

+32
-22
lines changed

Orange/base.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
from collections.abc import Iterable
44
import re
55
import warnings
6-
from typing import Callable, Dict
6+
from typing import Callable, Dict, Optional
77

88
import numpy as np
99
import scipy
1010

11-
from Orange.data import Table, Storage, Instance, Value
11+
from Orange.data import Table, Storage, Instance, Value, Domain
1212
from Orange.data.filter import HasClass
1313
from Orange.data.table import DomainTransformationError
1414
from Orange.data.util import one_hot
@@ -86,6 +86,9 @@ class Learner(ReprableWithPreprocessors):
8686
#: A sequence of data preprocessors to apply on data prior to
8787
#: fitting the model
8888
preprocessors = ()
89+
90+
# Note: Do not use this class attribute.
91+
# It remains here for compatibility reasons.
8992
learner_adequacy_err_msg = ''
9093

9194
def __init__(self, preprocessors=None):
@@ -106,8 +109,9 @@ def fit_storage(self, data):
106109
return self.fit(X, Y, W)
107110

108111
def __call__(self, data, progress_callback=None):
109-
if not self.check_learner_adequacy(data.domain):
110-
raise ValueError(self.learner_adequacy_err_msg)
112+
incompatibility_reason = self.incompatibility_reason(data.domain)
113+
if incompatibility_reason is not None:
114+
raise ValueError(incompatibility_reason)
111115

112116
origdomain = data.domain
113117

@@ -174,8 +178,16 @@ def active_preprocessors(self):
174178
yield from type(self).preprocessors
175179

176180
def check_learner_adequacy(self, _):
181+
warnings.warn(
182+
"check_learner_adequacy is deprecated and will be removed "
183+
"in upcoming releases. Please use incompatibility_reason "
184+
"to check if learner can fit data.",
185+
OrangeDeprecationWarning)
177186
return True
178187

188+
def incompatibility_reason(self, domain: Domain) -> Optional[str]:
189+
return None
190+
179191
@property
180192
def name(self):
181193
"""Return a short name derived from Learner type name"""

Orange/classification/base_classification.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@
66

77
class LearnerClassification(Learner):
88

9-
def check_learner_adequacy(self, domain):
10-
is_adequate = True
9+
def incompatibility_reason(self, domain):
10+
err_msg = None
1111
if len(domain.class_vars) > 1:
12-
is_adequate = False
13-
self.learner_adequacy_err_msg = "Too many target variables."
12+
err_msg = "Too many target variables."
1413
elif not domain.has_discrete_class:
15-
is_adequate = False
16-
self.learner_adequacy_err_msg = "Categorical class variable expected."
17-
return is_adequate
14+
err_msg = "Categorical class variable expected."
15+
return err_msg
1816

1917

2018
class ModelClassification(Model):

Orange/preprocess/impute.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,8 @@ def __call__(self, data, variable):
224224
variable = data.domain[variable]
225225
domain = domain_with_class_var(data.domain, variable)
226226

227-
if self.learner.check_learner_adequacy(domain):
227+
incompatibility_reason = self.learner.incompatibility_reason(data.domain)
228+
if incompatibility_reason is None:
228229
data = data.transform(domain)
229230
model = self.learner(data)
230231
assert model.domain.class_var == variable
@@ -239,7 +240,7 @@ def copy(self):
239240

240241
def supports_variable(self, variable):
241242
domain = Orange.data.Domain([], class_vars=variable)
242-
return self.learner.check_learner_adequacy(domain)
243+
return self.learner.incompatibility_reason(domain) is None
243244

244245

245246
def domain_with_class_var(domain, class_var):

Orange/regression/base_regression.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,13 @@
66

77
class LearnerRegression(Learner):
88

9-
def check_learner_adequacy(self, domain):
10-
is_adequate = True
9+
def incompatibility_reason(self, domain):
10+
err_msg = None
1111
if len(domain.class_vars) > 1:
12-
is_adequate = False
13-
self.learner_adequacy_err_msg = "Too many target variables."
12+
err_msg = "Too many target variables."
1413
elif not domain.has_continuous_class:
15-
is_adequate = False
16-
self.learner_adequacy_err_msg = "Numeric class variable expected."
17-
return is_adequate
14+
err_msg = "Numeric class variable expected."
15+
return err_msg
1816

1917

2018
class ModelRegression(Model):

Orange/widgets/utils/owlearnerwidget.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,9 @@ def check_data(self):
246246
self.Error.sparse_not_supported.clear()
247247
if self.data is not None and self.learner is not None:
248248
self.Error.data_error.clear()
249-
if not self.learner.check_learner_adequacy(self.data.domain):
250-
self.Error.data_error(self.learner.learner_adequacy_err_msg)
249+
incompatibility_reason = self.learner.incompatibility_reason(self.data.domain)
250+
if incompatibility_reason is not None:
251+
self.Error.data_error(incompatibility_reason)
251252
elif not len(self.data):
252253
self.Error.data_error("Dataset is empty.")
253254
elif len(ut.unique(self.data.Y)) < 2:

0 commit comments

Comments
 (0)