@@ -28,10 +28,45 @@ def _xgbclassifier_default(trial: optuna.trial.Trial):
2828 return param
2929
3030
31+ def _lgbmclassifier_default (trial : optuna .trial .Trial ):
32+ # TODO: using LightGBMTuner
33+ params = {
34+ 'boosting_type' : trial .suggest_categorical ('boosting' , ['gbdt' , 'dart' , 'goss' ]),
35+ 'objective' : 'binary' ,
36+ 'metric' : ['binary' , 'binary_error' , 'auc' ],
37+ 'num_leaves' : trial .suggest_int ("num_leaves" , 10 , 500 ),
38+ 'learning_rate' : trial .suggest_loguniform ("learning_rate" , 1e-5 , 1 ),
39+ 'feature_fraction' : trial .suggest_uniform ("feature_fraction" , 0.0 , 1.0 ),
40+ }
41+ if params ['boosting_type' ] == 'dart' :
42+ params ['drop_rate' ] = trial .suggest_loguniform ('drop_rate' , 1e-8 , 1.0 )
43+ params ['skip_drop' ] = trial .suggest_loguniform ('skip_drop' , 1e-8 , 1.0 )
44+ if params ['boosting_type' ] == 'goss' :
45+ params ['top_rate' ] = trial .suggest_uniform ('top_rate' , 0.0 , 1.0 )
46+ params ['other_rate' ] = trial .suggest_uniform ('other_rate' , 0.0 , 1.0 - params ['top_rate' ])
47+
48+ return params
49+
50+
51+ def _catboostclassifier_default (trial : optuna .trial .Trial ):
52+ params = {
53+ 'iterations' : trial .suggest_int ('iterations' , 50 , 300 ),
54+ 'depth' : trial .suggest_int ('depth' , 4 , 10 ),
55+ 'learning_rate' : trial .suggest_loguniform ('learning_rate' , 0.01 , 0.3 ),
56+ 'random_strength' : trial .suggest_int ('random_strength' , 0 , 100 ),
57+ 'bagging_temperature' : trial .suggest_loguniform ('bagging_temperature' , 0.01 , 100.00 ),
58+ 'od_type' : trial .suggest_categorical ('od_type' , ['IncToDec' , 'Iter' ]),
59+ 'od_wait' : trial .suggest_int ('od_wait' , 10 , 50 )
60+ }
61+
62+ return params
63+
3164class _OptunaParamFactory (metaclass = Singleton ):
3265 def __init__ (self ):
3366 self ._rules = dict ()
3467 self ._rules ['XGBClassifier_default' ] = _xgbclassifier_default
68+ self ._rules ['LGBMClassifier_default' ] = _lgbmclassifier_default
69+ self ._rules ['CatBoostClassifier_default' ] = _catboostclassifier_default
3570
3671 def get (self , key : str , trial : optuna .trial .Trial ):
3772 if key not in self ._rules :
0 commit comments