Skip to content

Commit da6682b

Browse files
committed
Enable test HyperDT with bank data
1 parent b5bdcfd commit da6682b

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tests/models/hyper_dt_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,32 +18,32 @@
1818

1919

2020
class Test_HyperDT():
21-
def bankdata(self):
21+
def test_bankdata(self):
2222
rs = RandomSearcher(mini_dt_space, optimize_direction=OptimizeDirection.Maximize, )
2323
hdt = HyperDT(rs,
2424
callbacks=[SummaryCallback(), FileLoggingCallback(rs, output_dir=f'{homedir}/hyn_logs')],
25-
reward_metric='accuracy',
26-
max_trials=3,
25+
# reward_metric='accuracy',
26+
reward_metric='AUC',
2727
dnn_params={
2828
'hidden_units': ((256, 0, False), (256, 0, False)),
2929
'dnn_activation': 'relu',
3030
},
3131
)
3232

33-
df = dsutils.load_bank()
33+
df = dsutils.load_bank().sample(frac=0.1, random_state=9527)
3434
df.drop(['id'], axis=1, inplace=True)
3535
df_train, df_test = train_test_split(df, test_size=0.2, random_state=42)
3636
y = df_train.pop('y')
3737
y_test = df_test.pop('y')
3838

39-
hdt.search(df_train, y, df_test, y_test)
39+
hdt.search(df_train, y, df_test, y_test, max_trials=3, )
4040
best_trial = hdt.get_best_trial()
4141
assert best_trial
4242

4343
estimator = hdt.final_train(best_trial.space_sample, df_train, y)
44-
score = estimator.predict(df)
45-
result = estimator.evaluate(df, y)
46-
assert len(score) == 100
44+
score = estimator.predict(df_test)
45+
result = estimator.evaluate(df_test, y_test)
46+
assert len(score) == len(y_test)
4747
assert result
4848
assert isinstance(estimator.model, DeepTable)
4949

0 commit comments

Comments
 (0)