Skip to content

Commit 4dc0084

Browse files
authored
Merge pull request #53 from lixfz/master
Enable test HyperDT with bank data
2 parents 754e288 + 2e8101e commit 4dc0084

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ script:
2828
notifications:
2929
recipients:
3030
31+
3132

3233
on_success: change
3334
on_failure: change

tests/models/hyper_dt_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,32 +18,32 @@
1818

1919

2020
class Test_HyperDT():
21-
def bankdata(self):
21+
def test_bankdata(self):
2222
rs = RandomSearcher(mini_dt_space, optimize_direction=OptimizeDirection.Maximize, )
2323
hdt = HyperDT(rs,
2424
callbacks=[SummaryCallback(), FileLoggingCallback(rs, output_dir=f'{homedir}/hyn_logs')],
25-
reward_metric='accuracy',
26-
max_trials=3,
25+
# reward_metric='accuracy',
26+
reward_metric='AUC',
2727
dnn_params={
2828
'hidden_units': ((256, 0, False), (256, 0, False)),
2929
'dnn_activation': 'relu',
3030
},
3131
)
3232

33-
df = dsutils.load_bank()
33+
df = dsutils.load_bank().sample(frac=0.1, random_state=9527)
3434
df.drop(['id'], axis=1, inplace=True)
3535
df_train, df_test = train_test_split(df, test_size=0.2, random_state=42)
3636
y = df_train.pop('y')
3737
y_test = df_test.pop('y')
3838

39-
hdt.search(df_train, y, df_test, y_test)
39+
hdt.search(df_train, y, df_test, y_test, max_trials=3, )
4040
best_trial = hdt.get_best_trial()
4141
assert best_trial
4242

4343
estimator = hdt.final_train(best_trial.space_sample, df_train, y)
44-
score = estimator.predict(df)
45-
result = estimator.evaluate(df, y)
46-
assert len(score) == 100
44+
score = estimator.predict(df_test)
45+
result = estimator.evaluate(df_test, y_test)
46+
assert len(score) == len(y_test)
4747
assert result
4848
assert isinstance(estimator.model, DeepTable)
4949

0 commit comments

Comments
 (0)