|
18 | 18 |
|
19 | 19 |
|
20 | 20 | class Test_HyperDT(): |
21 | | - def bankdata(self): |
| 21 | + def test_bankdata(self): |
22 | 22 | rs = RandomSearcher(mini_dt_space, optimize_direction=OptimizeDirection.Maximize, ) |
23 | 23 | hdt = HyperDT(rs, |
24 | 24 | callbacks=[SummaryCallback(), FileLoggingCallback(rs, output_dir=f'{homedir}/hyn_logs')], |
25 | | - reward_metric='accuracy', |
26 | | - max_trials=3, |
| 25 | + # reward_metric='accuracy', |
| 26 | + reward_metric='AUC', |
27 | 27 | dnn_params={ |
28 | 28 | 'hidden_units': ((256, 0, False), (256, 0, False)), |
29 | 29 | 'dnn_activation': 'relu', |
30 | 30 | }, |
31 | 31 | ) |
32 | 32 |
|
33 | | - df = dsutils.load_bank() |
| 33 | + df = dsutils.load_bank().sample(frac=0.1, random_state=9527) |
34 | 34 | df.drop(['id'], axis=1, inplace=True) |
35 | 35 | df_train, df_test = train_test_split(df, test_size=0.2, random_state=42) |
36 | 36 | y = df_train.pop('y') |
37 | 37 | y_test = df_test.pop('y') |
38 | 38 |
|
39 | | - hdt.search(df_train, y, df_test, y_test) |
| 39 | + hdt.search(df_train, y, df_test, y_test, max_trials=3, ) |
40 | 40 | best_trial = hdt.get_best_trial() |
41 | 41 | assert best_trial |
42 | 42 |
|
43 | 43 | estimator = hdt.final_train(best_trial.space_sample, df_train, y) |
44 | | - score = estimator.predict(df) |
45 | | - result = estimator.evaluate(df, y) |
46 | | - assert len(score) == 100 |
| 44 | + score = estimator.predict(df_test) |
| 45 | + result = estimator.evaluate(df_test, y_test) |
| 46 | + assert len(score) == len(y_test) |
47 | 47 | assert result |
48 | 48 | assert isinstance(estimator.model, DeepTable) |
49 | 49 |
|
|
0 commit comments