Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions Orange/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from collections.abc import Iterable
import re
import warnings
from typing import Callable, Dict
from typing import Callable, Dict, Optional

import numpy as np
import scipy

from Orange.data import Table, Storage, Instance, Value
from Orange.data import Table, Storage, Instance, Value, Domain
from Orange.data.filter import HasClass
from Orange.data.table import DomainTransformationError
from Orange.data.util import one_hot
Expand Down Expand Up @@ -86,6 +86,9 @@ class Learner(ReprableWithPreprocessors):
#: A sequence of data preprocessors to apply on data prior to
#: fitting the model
preprocessors = ()

# Deprecated: do not use this class attribute in new learners.
# It is kept for compatibility: `__call__` still raises ValueError with this
# message when a legacy `check_learner_adequacy` override rejects a domain.
learner_adequacy_err_msg = ''

def __init__(self, preprocessors=None):
Expand All @@ -95,6 +98,7 @@ def __init__(self, preprocessors=None):
elif preprocessors:
self.preprocessors = (preprocessors,)

# pylint: disable=R0201
def fit(self, X, Y, W=None):
    """Fit a model to ``X``, ``Y`` and the optional ``W``.

    This base implementation is a placeholder: it always raises
    RuntimeError, so descendants have to provide their own ``fit`` or
    ``fit_storage``.
    """
    raise RuntimeError(
        "Descendants of Learner must overload method fit or fit_storage")
Expand All @@ -106,8 +110,23 @@ def fit_storage(self, data):
return self.fit(X, Y, W)

def __call__(self, data, progress_callback=None):
if not self.check_learner_adequacy(data.domain):
raise ValueError(self.learner_adequacy_err_msg)

for cls in type(self).mro():
if 'incompatibility_reason' in cls.__dict__:
incompatibility_reason = \
self.incompatibility_reason(data.domain) # pylint: disable=assignment-from-none
if incompatibility_reason is not None:
raise ValueError(incompatibility_reason)
break
if 'check_learner_adequacy' in cls.__dict__:
warnings.warn(
"check_learner_adequacy is deprecated and will be removed "
"in upcoming releases. Learners should instead implement "
"the incompatibility_reason method.",
OrangeDeprecationWarning)
if not self.check_learner_adequacy(data.domain):
raise ValueError(self.learner_adequacy_err_msg)
break

origdomain = data.domain

Expand Down Expand Up @@ -176,6 +195,11 @@ def active_preprocessors(self):
def check_learner_adequacy(self, _):
    """Deprecated hook telling whether the learner accepts a domain.

    The default accepts every domain. ``__call__`` emits an
    ``OrangeDeprecationWarning`` for classes that override this method;
    learners should implement ``incompatibility_reason`` instead.
    """
    return True

# pylint: disable=no-self-use
def incompatibility_reason(self, _: "Domain") -> Optional[str]:
    """Return None if the learner can fit the domain, or a string
    explaining why it cannot.

    The default implementation accepts every domain.
    """
    return None

@property
def name(self):
"""Return a short name derived from Learner type name"""
Expand Down
10 changes: 7 additions & 3 deletions Orange/classification/base_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@


class LearnerClassification(Learner):
    """Base class for learners that require a categorical class variable."""

    # Legacy message, used only together with ``check_learner_adequacy``.
    learner_adequacy_err_msg = "Categorical class variable expected."

    def check_learner_adequacy(self, domain):
        """Legacy check; returns ``domain.has_discrete_class``."""
        return domain.has_discrete_class

    def incompatibility_reason(self, domain):
        """Return why this learner cannot fit ``domain``, or None if it can.

        More than one class variable is rejected unless the learner sets
        ``supports_multiclass``; otherwise the domain must have a discrete
        class.
        """
        if len(domain.class_vars) > 1 and not self.supports_multiclass:
            return "Too many target variables."
        if not domain.has_discrete_class:
            return "Categorical class variable expected."
        return None


class ModelClassification(Model):
Expand Down
4 changes: 4 additions & 0 deletions Orange/evaluation/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ def get_fold(self, fold):
class ClusteringScore(Score):
considers_actual = False

@staticmethod
def is_compatible(domain) -> bool:
return True

# pylint: disable=arguments-differ
def from_predicted(self, results, score_function):
# Clustering scores from labels
Expand Down
14 changes: 13 additions & 1 deletion Orange/evaluation/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import sklearn.metrics as skl_metrics
from sklearn.metrics import confusion_matrix

from Orange.data import DiscreteVariable, ContinuousVariable
from Orange.data import DiscreteVariable, ContinuousVariable, Domain
from Orange.misc.wrapper_meta import WrapperMeta

__all__ = ["CA", "Precision", "Recall", "F1", "PrecisionRecallFSupport", "AUC",
Expand Down Expand Up @@ -112,14 +112,26 @@ def from_predicted(results, score_function, **kwargs):
for predicted in results.predicted),
dtype=np.float64, count=len(results.predicted))

@staticmethod
def is_compatible(domain: Domain) -> bool:
raise NotImplementedError


class ClassificationScore(Score, abstract=True):
    # Abstract base for scores whose class type is DiscreteVariable.
    class_types = (DiscreteVariable, )

    @staticmethod
    def is_compatible(domain: Domain) -> bool:
        """Return ``domain.has_discrete_class``."""
        return domain.has_discrete_class


class RegressionScore(Score, abstract=True):
    # Abstract base for scores whose class type is ContinuousVariable.
    class_types = (ContinuousVariable, )

    @staticmethod
    def is_compatible(domain: Domain) -> bool:
        """Return ``domain.has_continuous_class``."""
        return domain.has_continuous_class


# pylint: disable=invalid-name
class CA(ClassificationScore):
Expand Down
14 changes: 7 additions & 7 deletions Orange/evaluation/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import sklearn.model_selection as skl

from Orange.data import Table, Domain, ContinuousVariable, DiscreteVariable
from Orange.data import Domain, ContinuousVariable, DiscreteVariable
from Orange.data.util import get_unique_names

__all__ = ["Results", "CrossValidation", "LeaveOneOut", "TestOnTrainingData",
Expand Down Expand Up @@ -37,9 +37,10 @@ def _mp_worker(fold_i, train_data, test_data, learner_i, learner,
train_time = time() - t0
t0 = time()
# testing
if train_data.domain.has_discrete_class:
class_var = train_data.domain.class_var
if class_var and class_var.is_discrete:
predicted, probs = model(test_data, model.ValueProbs)
elif train_data.domain.has_continuous_class:
else:
predicted = model(test_data, model.Value)
test_time = time() - t0
# Different models can fail at any time raising any exception
Expand Down Expand Up @@ -269,7 +270,7 @@ def get_augmented_data(self, model_names,
new_meta_vals = np.empty((len(data), 0))
names = [var.name for var in chain(domain.attributes,
domain.metas,
[class_var])]
domain.class_vars)]

if classification:
# predictions
Expand Down Expand Up @@ -501,8 +502,7 @@ def prepare_arrays(cls, data, indices):
ptr += len(test)

row_indices = np.concatenate(row_indices, axis=0)
actual = data[row_indices].Y.ravel()
return folds, row_indices, actual
return folds, row_indices, data[row_indices].Y

@staticmethod
def get_indices(data):
Expand Down Expand Up @@ -751,7 +751,7 @@ def __call__(self, data, test_data, learners, preprocessor=None,
nrows=len(test_data), learners=learners,
row_indices=np.arange(len(test_data)),
folds=(Ellipsis, ),
actual=test_data.Y.ravel(),
actual=test_data.Y,
score_by_folds=self.score_by_folds,
train_time=np.zeros((len(learners),)),
test_time=np.zeros((len(learners),)))
Expand Down
5 changes: 3 additions & 2 deletions Orange/preprocess/impute.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ def __call__(self, data, variable):
variable = data.domain[variable]
domain = domain_with_class_var(data.domain, variable)

if self.learner.check_learner_adequacy(domain):
incompatibility_reason = self.learner.incompatibility_reason(domain)
if incompatibility_reason is None:
data = data.transform(domain)
model = self.learner(data)
assert model.domain.class_var == variable
Expand All @@ -239,7 +240,7 @@ def copy(self):

def supports_variable(self, variable):
    """Tell whether ``self.learner`` can model ``variable`` as a target.

    A domain with no attributes and ``variable`` as its class is handed to
    the learner's ``incompatibility_reason``; the variable is supported
    when no reason is reported.
    """
    domain = Orange.data.Domain([], class_vars=variable)
    # Use the non-deprecated hook rather than ``check_learner_adequacy``.
    return self.learner.incompatibility_reason(domain) is None


def domain_with_class_var(domain, class_var):
Expand Down
10 changes: 7 additions & 3 deletions Orange/regression/base_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@


class LearnerRegression(Learner):
    """Base class for learners that require a numeric class variable."""

    # Legacy message, used only together with ``check_learner_adequacy``.
    learner_adequacy_err_msg = "Numeric class variable expected."

    def check_learner_adequacy(self, domain):
        """Legacy check; returns ``domain.has_continuous_class``."""
        return domain.has_continuous_class

    def incompatibility_reason(self, domain):
        """Return why this learner cannot fit ``domain``, or None if it can.

        More than one class variable is rejected unless the learner sets
        ``supports_multiclass``; otherwise the domain must have a continuous
        class.
        """
        if len(domain.class_vars) > 1 and not self.supports_multiclass:
            return "Too many target variables."
        if not domain.has_continuous_class:
            return "Numeric class variable expected."
        return None


class ModelRegression(Model):
Expand Down
5 changes: 3 additions & 2 deletions Orange/tests/dummy_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ def __init__(self, value, prob):
class DummyMulticlassLearner(SklLearner):
supports_multiclass = True

def check_learner_adequacy(self, domain):
    """Legacy check: True when every class variable of ``domain`` is
    discrete (also True when there are no class variables)."""
    return all(c.is_discrete for c in domain.class_vars)
def incompatibility_reason(self, domain):
    """Return None when every class variable of ``domain`` is discrete,
    otherwise a message saying that not all of them are."""
    if all(c.is_discrete for c in domain.class_vars):
        return None
    return 'Not all class variables are discrete'

def fit(self, X, Y, W):
rows, class_vars = Y.shape
Expand Down
31 changes: 31 additions & 0 deletions Orange/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,59 @@
# pylint: disable=missing-docstring
import pickle
import unittest
from distutils.version import LooseVersion

import Orange
from Orange.base import SklLearner, Learner, Model
from Orange.data import Domain, Table
from Orange.preprocess import Discretize, Randomize, Continuize
from Orange.regression import LinearRegressionLearner
from Orange.util import OrangeDeprecationWarning


class DummyLearnerDeprecated(Learner):
    """Learner that still overrides the deprecated adequacy hook."""

    def fit(self, *_, **__):
        # Import explicitly: a bare ``import unittest`` does not load the
        # ``unittest.mock`` submodule.
        from unittest.mock import Mock  # pylint: disable=import-outside-toplevel
        return Mock()

    def check_learner_adequacy(self, _):
        return True


class DummyLearner(Learner):
    """Minimal learner whose ``fit`` returns a mock object."""

    def fit(self, *_, **__):
        # Import explicitly: a bare ``import unittest`` does not load the
        # ``unittest.mock`` submodule.
        from unittest.mock import Mock  # pylint: disable=import-outside-toplevel
        return Mock()


class DummySklLearner(SklLearner):
    """Minimal SklLearner whose ``fit`` returns a mock object."""

    def fit(self, *_, **__):
        # Import explicitly: a bare ``import unittest`` does not load the
        # ``unittest.mock`` submodule.
        from unittest.mock import Mock  # pylint: disable=import-outside-toplevel
        return Mock()


class DummyLearnerPP(Learner):
    """Learner that declares a non-empty class-level ``preprocessors``."""

    preprocessors = (Randomize(),)


class TestLearner(unittest.TestCase):

def test_if_deprecation_warning_is_raised(self):
    """Calling a learner that overrides ``check_learner_adequacy`` must
    emit an OrangeDeprecationWarning."""
    with self.assertWarns(OrangeDeprecationWarning):
        DummyLearnerDeprecated()(Table('iris'))

def test_check_learner_adequacy_deprecated(self):
    """This test is to be included in the 3.32 release and will fail in
    version 3.34. This serves as a reminder to remove the deprecated method
    and this test."""
    # ``distutils`` was removed from the standard library in Python 3.12
    # (PEP 632), so compare the leading numeric version components
    # directly instead of relying on ``LooseVersion``.
    import re  # pylint: disable=import-outside-toplevel
    major_minor = tuple(
        int(part) for part in re.findall(r"\d+", Orange.__version__)[:2])
    if major_minor >= (3, 34):
        self.fail(
            "`Orange.base.Learner.check_learner_adequacy` was deprecated in "
            "version 3.32, and there have been two minor versions in "
            "between. Please remove the deprecated method."
        )

def test_uses_default_preprocessors_unless_custom_pps_specified(self):
"""Learners should use their default preprocessors unless custom
preprocessors were passed in to the constructor"""
Expand Down
Loading