Skip to content

Commit 13a770f

Browse files
committed
Update HyperDT
1 parent d1f3999 commit 13a770f

File tree

1 file changed

+34
-21
lines changed

1 file changed

+34
-21
lines changed

deeptables/models/hyper_dt.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
55
"""
66
import copy
7-
import pickle
87

98
import pandas as pd
9+
import pickle
1010

1111
from deeptables.models.config import ModelConfig
1212
from deeptables.models.deeptable import DeepTable
@@ -20,6 +20,12 @@
2020
logger = dt_logging.get_logger(__name__)
2121

2222

23+
def _to_hp(v):
24+
if isinstance(v, (list, tuple)):
25+
v = Choice(v)
26+
return v
27+
28+
2329
class DTModuleSpace(ModuleSpace):
2430
def __init__(self, space=None, name=None, **hyperparams):
2531
ModuleSpace.__init__(self, space, name, **hyperparams)
@@ -37,13 +43,16 @@ def _on_params_ready(self):
3743

3844

3945
class DTFit(ModuleSpace):
40-
def __init__(self, batch_size=None, epochs=None, space=None, name=None, **hyperparams):
41-
if batch_size is None:
42-
batch_size = Choice([128, 256])
43-
hyperparams['batch_size'] = batch_size
46+
def __init__(self, space=None, name=None, **hyperparams):
47+
# if batch_size is None:
48+
# batch_size = Choice([128, 256])
49+
# hyperparams['batch_size'] = batch_size
50+
#
51+
# if epochs is not None:
52+
# hyperparams['epochs'] = epochs
4453

45-
if epochs is not None:
46-
hyperparams['epochs'] = epochs
54+
for k, v in hyperparams.items():
55+
hyperparams[k] = _to_hp(v)
4756

4857
ModuleSpace.__init__(self, space, name, **hyperparams)
4958
self.space.fit_params = self
@@ -62,24 +71,24 @@ class DnnModule(ModuleSpace):
6271
def __init__(self, hidden_units=None, reduce_factor=None, dnn_dropout=None, use_bn=None, dnn_layers=None,
6372
activation=None, space=None, name=None, **hyperparams):
6473
if hidden_units is None:
65-
hidden_units = Choice([100, 200, 300, 500, 800, 1000])
66-
hyperparams['hidden_units'] = hidden_units
74+
hidden_units = [100, 200, 300, 500, 800, 1000]
75+
hyperparams['hidden_units'] = _to_hp(hidden_units)
6776

6877
if reduce_factor is None:
69-
reduce_factor = Choice([1, 0.8, 0.5])
70-
hyperparams['reduce_factor'] = reduce_factor
78+
reduce_factor = [1, 0.8, 0.5]
79+
hyperparams['reduce_factor'] = _to_hp(reduce_factor)
7180

7281
if dnn_dropout is None:
73-
dnn_dropout = Choice([0, 0.1, 0.3, 0.5])
74-
hyperparams['dnn_dropout'] = dnn_dropout
82+
dnn_dropout = [0, 0.1, 0.3, 0.5]
83+
hyperparams['dnn_dropout'] = _to_hp(dnn_dropout)
7584

7685
if use_bn is None:
7786
use_bn = Bool()
7887
hyperparams['use_bn'] = use_bn
7988

8089
if dnn_layers is None:
81-
dnn_layers = Choice([1, 2, 3])
82-
hyperparams['dnn_layers'] = dnn_layers
90+
dnn_layers = [1, 2, 3]
91+
hyperparams['dnn_layers'] = _to_hp(dnn_layers)
8392

8493
if activation is None:
8594
activation = 'relu'
@@ -141,22 +150,25 @@ def summary(self):
141150
# logger.info(ex)
142151

143152
def fit(self, X, y, eval_set=None, pos_label=None, n_jobs=1, **kwargs):
144-
fit_params = self.space_sample.__dict__.get('fit_params')
145-
if fit_params is not None:
146-
kwargs.update(fit_params.param_values)
153+
# fit_params = self.space_sample.__dict__.get('fit_params')
154+
# if fit_params is not None:
155+
# kwargs.update(fit_params.param_values)
147156
if kwargs.get('cross_validation') is not None:
148157
kwargs.pop('cross_validation')
149158
self.model.fit_cross_validation(X, y, n_jobs=n_jobs, **kwargs)
150159
else:
151-
self.model.fit(X, y, **kwargs)
160+
fit_kwargs = self.space_sample.fit_params.param_values.copy()
161+
fit_kwargs.update(kwargs)
162+
self.model.fit(X, y, **fit_kwargs)
152163

153164
self.classes_ = getattr(self.model, 'classes_', None)
154165
return self
155166

156167
def fit_cross_validation(self, X, y, eval_set=None, metrics=None, pos_label=None, **kwargs):
157168
assert isinstance(metrics, (list, tuple))
158-
159-
oof_proba, _, _, oof_scores = self.model.fit_cross_validation(X, y, oof_metrics=metrics, **kwargs)
169+
fit_kwargs = self.space_sample.fit_params.param_values.copy()
170+
fit_kwargs.update(kwargs)
171+
oof_proba, _, _, oof_scores = self.model.fit_cross_validation(X, y, oof_metrics=metrics, **fit_kwargs)
160172

161173
# calc final score with mean
162174
scores = pd.concat([pd.Series(s) for s in oof_scores], axis=1).mean(axis=1).to_dict()
@@ -360,6 +372,7 @@ def tiny_dt_space(**hyperparams):
360372
use_bn=False,
361373
dnn_layers=2,
362374
activation='relu')(dt_module)
375+
hyperparams['batch_size'] = [64, 100]
363376
fit = DTFit(**hyperparams)(dt_module)
364377

365378
return space

0 commit comments

Comments
 (0)