@@ -152,3 +152,43 @@ class DummyFitter(Fitter):
152152 pp_data = fitter .preprocess (self .heart_disease )
153153 self .assertTrue (not any (
154154 isinstance (v , ContinuousVariable ) for v in pp_data .domain .variables ))
155+
156+ def test_default_kwargs_with_change_kwargs (self ):
157+ """Fallback to default args in case specialized params not specified.
158+ """
159+ class DummyClassificationLearner (LearnerClassification ):
160+ def __init__ (self , param = 'classification_default' , ** _ ):
161+ super ().__init__ ()
162+ self .param = param
163+
164+ def fit_storage (self , data ):
165+ return DummyModel (self .param )
166+
167+ class DummyRegressionLearner (LearnerRegression ):
168+ def __init__ (self , param = 'regression_default' , ** _ ):
169+ super ().__init__ ()
170+ self .param = param
171+
172+ def fit_storage (self , data ):
173+ return DummyModel (self .param )
174+
175+ class DummyModel :
176+ def __init__ (self , param ):
177+ self .param = param
178+
179+ class DummyFitter (Fitter ):
180+ __fits__ = {'classification' : DummyClassificationLearner ,
181+ 'regression' : DummyRegressionLearner }
182+
183+ def _change_kwargs (self , kwargs , problem_type ):
184+ if problem_type == self .CLASSIFICATION :
185+ kwargs ['param' ] = kwargs .get ('classification_param' )
186+ else :
187+ kwargs ['param' ] = kwargs .get ('regression_param' )
188+ return kwargs
189+
190+ learner = DummyFitter ()
191+ iris , housing = Table ('iris' )[:5 ], Table ('housing' )[:5 ]
192+ self .assertEqual (learner (iris ).param , 'classification_default' )
193+ self .assertEqual (learner (housing ).param , 'regression_default' )
194+
0 commit comments