@@ -38,17 +38,22 @@ class Fitter(Learner, metaclass=FitterMeta):
3838
3939 def __init__ (self , preprocessors = None , ** kwargs ):
4040 super ().__init__ (preprocessors = preprocessors )
41- self .kwargs = kwargs
41+ self .params = kwargs
4242 # Make sure to pass preprocessor params to individual learners
43- self .kwargs ['preprocessors' ] = preprocessors
44- self .problem_type = None
43+ self .params ['preprocessors' ] = preprocessors
4544 self .__learners = {self .CLASSIFICATION : None , self .REGRESSION : None }
4645
47- def __call__ (self , data ):
48- # Set the appropriate problem type from the data
49- self .problem_type = self .CLASSIFICATION if \
50- data .domain .has_discrete_class else self .REGRESSION
51- return self .get_learner (self .problem_type )(data )
46+ def _fit_model (self , data ):
47+ if data .domain .has_discrete_class :
48+ learner = self .get_learner (self .CLASSIFICATION )
49+ else :
50+ learner = self .get_learner (self .REGRESSION )
51+
52+ if type (self ).fit is Learner .fit :
53+ return learner .fit_storage (data )
54+ else :
55+ X , Y , W = data .X , data .Y , data .W if data .has_weights () else None
56+ return learner .fit (X , Y , W )
5257
5358 def get_learner (self , problem_type ):
5459 """Get the learner for a given problem type."""
@@ -64,7 +69,7 @@ def get_learner(self, problem_type):
6469 def __kwargs (self , problem_type ):
6570 learner_kwargs = set (
6671 self .__fits__ [problem_type ].__init__ .__code__ .co_varnames [1 :])
67- changed_kwargs = self ._change_kwargs (self .kwargs , self . problem_type )
72+ changed_kwargs = self ._change_kwargs (self .params , problem_type )
6873 return {k : v for k , v in changed_kwargs .items () if k in learner_kwargs }
6974
7075 def _change_kwargs (self , kwargs , problem_type ):
@@ -90,9 +95,3 @@ def supports_weights(self):
9095 and self .get_learner (self .CLASSIFICATION ).supports_weights ) and (
9196 hasattr (self .get_learner (self .REGRESSION ), 'supports_weights' )
9297 and self .get_learner (self .REGRESSION ).supports_weights )
93-
94- def __getattr__ (self , item ):
95- # Make parameters accessible on the learner for simpler testing
96- if item in self .kwargs :
97- return self .kwargs [item ]
98- return getattr (self .get_learner (self .problem_type ), item )
0 commit comments