44
55"""
66import copy
7- import pickle
87
98import pandas as pd
9+ import pickle
1010
1111from deeptables .models .config import ModelConfig
1212from deeptables .models .deeptable import DeepTable
2020logger = dt_logging .get_logger (__name__ )
2121
2222
23+ def _to_hp (v ):
24+ if isinstance (v , (list , tuple )):
25+ v = Choice (v )
26+ return v
27+
28+
2329class DTModuleSpace (ModuleSpace ):
2430 def __init__ (self , space = None , name = None , ** hyperparams ):
2531 ModuleSpace .__init__ (self , space , name , ** hyperparams )
@@ -37,13 +43,16 @@ def _on_params_ready(self):
3743
3844
3945class DTFit (ModuleSpace ):
40- def __init__ (self , batch_size = None , epochs = None , space = None , name = None , ** hyperparams ):
41- if batch_size is None :
42- batch_size = Choice ([128 , 256 ])
43- hyperparams ['batch_size' ] = batch_size
46+ def __init__ (self , space = None , name = None , ** hyperparams ):
47+ # if batch_size is None:
48+ # batch_size = Choice([128, 256])
49+ # hyperparams['batch_size'] = batch_size
50+ #
51+ # if epochs is not None:
52+ # hyperparams['epochs'] = epochs
4453
45- if epochs is not None :
46- hyperparams ['epochs' ] = epochs
54+ for k , v in hyperparams . items () :
55+ hyperparams [k ] = _to_hp ( v )
4756
4857 ModuleSpace .__init__ (self , space , name , ** hyperparams )
4958 self .space .fit_params = self
@@ -62,24 +71,24 @@ class DnnModule(ModuleSpace):
6271 def __init__ (self , hidden_units = None , reduce_factor = None , dnn_dropout = None , use_bn = None , dnn_layers = None ,
6372 activation = None , space = None , name = None , ** hyperparams ):
6473 if hidden_units is None :
65- hidden_units = Choice ( [100 , 200 , 300 , 500 , 800 , 1000 ])
66- hyperparams ['hidden_units' ] = hidden_units
74+ hidden_units = [100 , 200 , 300 , 500 , 800 , 1000 ]
75+ hyperparams ['hidden_units' ] = _to_hp ( hidden_units )
6776
6877 if reduce_factor is None :
69- reduce_factor = Choice ( [1 , 0.8 , 0.5 ])
70- hyperparams ['reduce_factor' ] = reduce_factor
78+ reduce_factor = [1 , 0.8 , 0.5 ]
79+ hyperparams ['reduce_factor' ] = _to_hp ( reduce_factor )
7180
7281 if dnn_dropout is None :
73- dnn_dropout = Choice ( [0 , 0.1 , 0.3 , 0.5 ])
74- hyperparams ['dnn_dropout' ] = dnn_dropout
82+ dnn_dropout = [0 , 0.1 , 0.3 , 0.5 ]
83+ hyperparams ['dnn_dropout' ] = _to_hp ( dnn_dropout )
7584
7685 if use_bn is None :
7786 use_bn = Bool ()
7887 hyperparams ['use_bn' ] = use_bn
7988
8089 if dnn_layers is None :
81- dnn_layers = Choice ( [1 , 2 , 3 ])
82- hyperparams ['dnn_layers' ] = dnn_layers
90+ dnn_layers = [1 , 2 , 3 ]
91+ hyperparams ['dnn_layers' ] = _to_hp ( dnn_layers )
8392
8493 if activation is None :
8594 activation = 'relu'
@@ -141,22 +150,25 @@ def summary(self):
141150 # logger.info(ex)
142151
143152 def fit (self , X , y , eval_set = None , pos_label = None , n_jobs = 1 , ** kwargs ):
144- fit_params = self .space_sample .__dict__ .get ('fit_params' )
145- if fit_params is not None :
146- kwargs .update (fit_params .param_values )
153+ # fit_params = self.space_sample.__dict__.get('fit_params')
154+ # if fit_params is not None:
155+ # kwargs.update(fit_params.param_values)
147156 if kwargs .get ('cross_validation' ) is not None :
148157 kwargs .pop ('cross_validation' )
149158 self .model .fit_cross_validation (X , y , n_jobs = n_jobs , ** kwargs )
150159 else :
151- self .model .fit (X , y , ** kwargs )
160+ fit_kwargs = self .space_sample .fit_params .param_values .copy ()
161+ fit_kwargs .update (kwargs )
162+ self .model .fit (X , y , ** fit_kwargs )
152163
153164 self .classes_ = getattr (self .model , 'classes_' , None )
154165 return self
155166
156167 def fit_cross_validation (self , X , y , eval_set = None , metrics = None , pos_label = None , ** kwargs ):
157168 assert isinstance (metrics , (list , tuple ))
158-
159- oof_proba , _ , _ , oof_scores = self .model .fit_cross_validation (X , y , oof_metrics = metrics , ** kwargs )
169+ fit_kwargs = self .space_sample .fit_params .param_values .copy ()
170+ fit_kwargs .update (kwargs )
171+ oof_proba , _ , _ , oof_scores = self .model .fit_cross_validation (X , y , oof_metrics = metrics , ** fit_kwargs )
160172
161173 # calc final score with mean
162174 scores = pd .concat ([pd .Series (s ) for s in oof_scores ], axis = 1 ).mean (axis = 1 ).to_dict ()
@@ -360,6 +372,7 @@ def tiny_dt_space(**hyperparams):
360372 use_bn = False ,
361373 dnn_layers = 2 ,
362374 activation = 'relu' )(dt_module )
375+ hyperparams ['batch_size' ] = [64 , 100 ]
363376 fit = DTFit (** hyperparams )(dt_module )
364377
365378 return space
0 commit comments