Skip to content

Commit 414c29c

Browse files
Fitter: _change_kwargs falls back to learner defaults when params are unset
1 parent 571be99 commit 414c29c

File tree

3 files changed

+47
-5
lines changed

3 files changed

+47
-5
lines changed

Orange/modelling/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def __kwargs(self, problem_type):
8383
learner_kwargs = set(
8484
self.__fits__[problem_type].__init__.__code__.co_varnames[1:])
8585
changed_kwargs = self._change_kwargs(self.kwargs, problem_type)
86-
return {k: v for k, v in changed_kwargs.items() if k in learner_kwargs}
86+
# Drop params with falsy values (None, 0, '', False, ...) so the
87+
filtered_kwargs = {k: v for k, v in changed_kwargs.items() if v}
88+
return {k: v for k, v in filtered_kwargs.items() if k in learner_kwargs}
8789

8890
def _change_kwargs(self, kwargs, problem_type):
8991
"""Handle the kwargs to be passed to the learner before they are used.

Orange/modelling/linear.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ class SGDLearner(Fitter):
1111

1212
def _change_kwargs(self, kwargs, problem_type):
1313
if problem_type is self.CLASSIFICATION:
14-
kwargs['loss'] = kwargs['classification_loss']
15-
kwargs['epsilon'] = kwargs['classification_epsilon']
14+
kwargs['loss'] = kwargs.get('classification_loss')
15+
kwargs['epsilon'] = kwargs.get('classification_epsilon')
1616
elif problem_type is self.REGRESSION:
17-
kwargs['loss'] = kwargs['regression_loss']
18-
kwargs['epsilon'] = kwargs['regression_epsilon']
17+
kwargs['loss'] = kwargs.get('regression_loss')
18+
kwargs['epsilon'] = kwargs.get('regression_epsilon')
1919
return kwargs

Orange/tests/test_fitter.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,43 @@ class DummyFitter(Fitter):
152152
pp_data = fitter.preprocess(self.heart_disease)
153153
self.assertTrue(not any(
154154
isinstance(v, ContinuousVariable) for v in pp_data.domain.variables))
155+
156+
def test_default_kwargs_with_change_kwargs(self):
157+
"""Fallback to default args in case specialized params not specified.
158+
"""
159+
class DummyClassificationLearner(LearnerClassification):
160+
def __init__(self, param='classification_default', **_):
161+
super().__init__()
162+
self.param = param
163+
164+
def fit_storage(self, data):
165+
return DummyModel(self.param)
166+
167+
class DummyRegressionLearner(LearnerRegression):
168+
def __init__(self, param='regression_default', **_):
169+
super().__init__()
170+
self.param = param
171+
172+
def fit_storage(self, data):
173+
return DummyModel(self.param)
174+
175+
class DummyModel:
176+
def __init__(self, param):
177+
self.param = param
178+
179+
class DummyFitter(Fitter):
180+
__fits__ = {'classification': DummyClassificationLearner,
181+
'regression': DummyRegressionLearner}
182+
183+
def _change_kwargs(self, kwargs, problem_type):
184+
if problem_type == self.CLASSIFICATION:
185+
kwargs['param'] = kwargs.get('classification_param')
186+
else:
187+
kwargs['param'] = kwargs.get('regression_param')
188+
return kwargs
189+
190+
learner = DummyFitter()
191+
iris, housing = Table('iris')[:5], Table('housing')[:5]
192+
self.assertEqual(learner(iris).param, 'classification_default')
193+
self.assertEqual(learner(housing).param, 'regression_default')
194+

0 commit comments

Comments
 (0)