Skip to content

Commit 1fc9cd8

Browse files
authored
Merge pull request #2127 from pavlin-policar/fitter-change-kwargs-default-params
[FIX] Fitter: fall back to default params when changed kwargs are None
2 parents a420f95 + 417e550 commit 1fc9cd8

File tree

3 files changed

+51
-10
lines changed

3 files changed

+51
-10
lines changed

Orange/modelling/base.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class Fitter(Learner, metaclass=FitterMeta):
3030
learners.
3131
3232
"""
33-
__fits__ = None
33+
__fits__ = {}
3434
__returns__ = Model
3535

3636
# Constants to indicate what kind of problem we're dealing with
@@ -83,7 +83,9 @@ def __kwargs(self, problem_type):
8383
learner_kwargs = set(
8484
self.__fits__[problem_type].__init__.__code__.co_varnames[1:])
8585
changed_kwargs = self._change_kwargs(self.kwargs, problem_type)
86-
return {k: v for k, v in changed_kwargs.items() if k in learner_kwargs}
86+
# Make sure to remove any params that are set to None and use defaults
87+
filtered_kwargs = {k: v for k, v in changed_kwargs.items() if v is not None}
88+
return {k: v for k, v in filtered_kwargs.items() if k in learner_kwargs}
8789

8890
def _change_kwargs(self, kwargs, problem_type):
8991
"""Handle the kwargs to be passed to the learner before they are used.
@@ -104,10 +106,9 @@ def supports_weights(self):
104106
"""The fitter supports weights if both the classification and
105107
regression learners support weights."""
106108
return (
107-
hasattr(self.get_learner(self.CLASSIFICATION), 'supports_weights')
108-
and self.get_learner(self.CLASSIFICATION).supports_weights) and (
109-
hasattr(self.get_learner(self.REGRESSION), 'supports_weights')
110-
and self.get_learner(self.REGRESSION).supports_weights)
109+
getattr(self.get_learner(self.CLASSIFICATION), 'supports_weights', False) and
110+
getattr(self.get_learner(self.REGRESSION), 'supports_weights', False)
111+
)
111112

112113
@property
113114
def params(self):

Orange/modelling/linear.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ class SGDLearner(Fitter):
1313

1414
def _change_kwargs(self, kwargs, problem_type):
1515
if problem_type is self.CLASSIFICATION:
16-
kwargs['loss'] = kwargs['classification_loss']
17-
kwargs['epsilon'] = kwargs['classification_epsilon']
16+
kwargs['loss'] = kwargs.get('classification_loss')
17+
kwargs['epsilon'] = kwargs.get('classification_epsilon')
1818
elif problem_type is self.REGRESSION:
19-
kwargs['loss'] = kwargs['regression_loss']
20-
kwargs['epsilon'] = kwargs['regression_epsilon']
19+
kwargs['loss'] = kwargs.get('regression_loss')
20+
kwargs['epsilon'] = kwargs.get('regression_epsilon')
2121
return kwargs

Orange/tests/test_fitter.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,43 @@ class DummyFitter(Fitter):
152152
pp_data = fitter.preprocess(self.heart_disease)
153153
self.assertTrue(not any(
154154
isinstance(v, ContinuousVariable) for v in pp_data.domain.variables))
155+
156+
def test_default_kwargs_with_change_kwargs(self):
157+
"""Fallback to default args in case specialized params not specified.
158+
"""
159+
class DummyClassificationLearner(LearnerClassification):
160+
def __init__(self, param='classification_default', **_):
161+
super().__init__()
162+
self.param = param
163+
164+
def fit_storage(self, data):
165+
return DummyModel(self.param)
166+
167+
class DummyRegressionLearner(LearnerRegression):
168+
def __init__(self, param='regression_default', **_):
169+
super().__init__()
170+
self.param = param
171+
172+
def fit_storage(self, data):
173+
return DummyModel(self.param)
174+
175+
class DummyModel:
176+
def __init__(self, param):
177+
self.param = param
178+
179+
class DummyFitter(Fitter):
180+
__fits__ = {'classification': DummyClassificationLearner,
181+
'regression': DummyRegressionLearner}
182+
183+
def _change_kwargs(self, kwargs, problem_type):
184+
if problem_type == self.CLASSIFICATION:
185+
kwargs['param'] = kwargs.get('classification_param')
186+
else:
187+
kwargs['param'] = kwargs.get('regression_param')
188+
return kwargs
189+
190+
learner = DummyFitter()
191+
iris, housing = Table('iris')[:5], Table('housing')[:5]
192+
self.assertEqual(learner(iris).param, 'classification_default')
193+
self.assertEqual(learner(housing).param, 'regression_default')
194+

0 commit comments

Comments (0)