Skip to content

Commit e8423b3

Browse files
Fitter: Remove state from fitter
1 parent 1f1a3a4 commit e8423b3

File tree

2 files changed

+22
-21
lines changed

2 files changed

+22
-21
lines changed

Orange/base.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,19 +114,21 @@ def __call__(self, data):
114114
self.__class__.__name__)
115115

116116
self.domain = data.domain
117-
118-
if type(self).fit is Learner.fit:
119-
model = self.fit_storage(data)
120-
else:
121-
X, Y, W = data.X, data.Y, data.W if data.has_weights() else None
122-
model = self.fit(X, Y, W)
117+
model = self._fit_model(data)
123118
model.domain = data.domain
124119
model.supports_multiclass = self.supports_multiclass
125120
model.name = self.name
126121
model.original_domain = origdomain
127122
model.original_data = origdata
128123
return model
129124

125+
def _fit_model(self, data):
126+
if type(self).fit is Learner.fit:
127+
return self.fit_storage(data)
128+
else:
129+
X, Y, W = data.X, data.Y, data.W if data.has_weights() else None
130+
return self.fit(X, Y, W)
131+
130132
def preprocess(self, data):
131133
"""Apply the `preprocessors` to the data"""
132134
for pp in self.active_preprocessors:

Orange/modelling/base.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,22 @@ class Fitter(Learner, metaclass=FitterMeta):
3838

3939
def __init__(self, preprocessors=None, **kwargs):
4040
super().__init__(preprocessors=preprocessors)
41-
self.kwargs = kwargs
41+
self.params = kwargs
4242
# Make sure to pass preprocessor params to individual learners
43-
self.kwargs['preprocessors'] = preprocessors
44-
self.problem_type = None
43+
self.params['preprocessors'] = preprocessors
4544
self.__learners = {self.CLASSIFICATION: None, self.REGRESSION: None}
4645

47-
def __call__(self, data):
48-
# Set the appropriate problem type from the data
49-
self.problem_type = self.CLASSIFICATION if \
50-
data.domain.has_discrete_class else self.REGRESSION
51-
return self.get_learner(self.problem_type)(data)
46+
def _fit_model(self, data):
47+
if data.domain.has_discrete_class:
48+
learner = self.get_learner(self.CLASSIFICATION)
49+
else:
50+
learner = self.get_learner(self.REGRESSION)
51+
52+
if type(self).fit is Learner.fit:
53+
return learner.fit_storage(data)
54+
else:
55+
X, Y, W = data.X, data.Y, data.W if data.has_weights() else None
56+
return learner.fit(X, Y, W)
5257

5358
def get_learner(self, problem_type):
5459
"""Get the learner for a given problem type."""
@@ -64,7 +69,7 @@ def get_learner(self, problem_type):
6469
def __kwargs(self, problem_type):
6570
learner_kwargs = set(
6671
self.__fits__[problem_type].__init__.__code__.co_varnames[1:])
67-
changed_kwargs = self._change_kwargs(self.kwargs, self.problem_type)
72+
changed_kwargs = self._change_kwargs(self.params, problem_type)
6873
return {k: v for k, v in changed_kwargs.items() if k in learner_kwargs}
6974

7075
def _change_kwargs(self, kwargs, problem_type):
@@ -90,9 +95,3 @@ def supports_weights(self):
9095
and self.get_learner(self.CLASSIFICATION).supports_weights) and (
9196
hasattr(self.get_learner(self.REGRESSION), 'supports_weights')
9297
and self.get_learner(self.REGRESSION).supports_weights)
93-
94-
def __getattr__(self, item):
95-
# Make parameters accessible on the learner for simpler testing
96-
if item in self.kwargs:
97-
return self.kwargs[item]
98-
return getattr(self.get_learner(self.problem_type), item)

0 commit comments

Comments (0)