Skip to content

Commit 8364ae8

Browse files
authored
Merge pull request #2093 from pavlin-policar/fix-fitter-preprocessors
[FIX] Fitter: Properly delegate preprocessors
2 parents 81afdde + cdf77bb commit 8364ae8

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

Orange/modelling/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ def _fit_model(self, data):
5555
X, Y, W = data.X, data.Y, data.W if data.has_weights() else None
5656
return learner.fit(X, Y, W)
5757

58+
def preprocess(self, data):
59+
if data.domain.has_discrete_class:
60+
return self.get_learner(self.CLASSIFICATION).preprocess(data)
61+
else:
62+
return self.get_learner(self.REGRESSION).preprocess(data)
63+
5864
def get_learner(self, problem_type):
5965
"""Get the learner for a given problem type.
6066

Orange/tests/test_fitter.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from unittest.mock import Mock, patch
33

44
from Orange.classification.base_classification import LearnerClassification
5-
from Orange.data import Table
5+
from Orange.data import Table, ContinuousVariable
66
from Orange.evaluation import CrossValidation
77
from Orange.modelling import Fitter
8-
from Orange.preprocess import Randomize
8+
from Orange.preprocess import Randomize, Discretize
99
from Orange.regression.base_regression import LearnerRegression
1010

1111

@@ -130,3 +130,25 @@ def test_correctly_sets_preprocessors_on_learner(self):
130130
def test_n_jobs_fitting(self):
131131
with patch('Orange.evaluation.testing.CrossValidation._MIN_NJOBS_X_SIZE', 1):
132132
CrossValidation(self.heart_disease, [DummyFitter()], k=5, n_jobs=5)
133+
134+
def test_properly_delegates_preprocessing(self):
135+
class DummyClassificationLearner(LearnerClassification):
136+
preprocessors = [Discretize()]
137+
138+
def __init__(self, classification_param=1, **_):
139+
super().__init__()
140+
self.param = classification_param
141+
142+
class DummyFitter(Fitter):
143+
__fits__ = {'classification': DummyClassificationLearner,
144+
'regression': DummyRegressionLearner}
145+
146+
data = self.heart_disease
147+
fitter = DummyFitter()
148+
# Sanity check
149+
self.assertTrue(any(
150+
isinstance(v, ContinuousVariable) for v in data.domain.variables))
151+
# Preprocess the data and check that the discretization was applied
152+
pp_data = fitter.preprocess(self.heart_disease)
153+
self.assertTrue(not any(
154+
isinstance(v, ContinuousVariable) for v in pp_data.domain.variables))

0 commit comments

Comments
 (0)