|
1 | | -import keras_tuner as kt |
2 | | -import tensorflow as tf |
| 1 | +import argparse |
| 2 | + |
| 3 | +from keras import Model |
| 4 | +from keras.callbacks import EarlyStopping, ReduceLROnPlateau |
3 | 5 |
|
4 | 6 | from seg_tgce.data.oxford_pet.oxford_pet import ( |
5 | 7 | fetch_models, |
6 | 8 | get_data_multiple_annotators, |
7 | 9 | ) |
8 | 10 | from seg_tgce.experiments.plot_utils import plot_training_history, print_test_metrics |
| 11 | +from seg_tgce.experiments.types import HpTunerTrial |
| 12 | +from seg_tgce.experiments.utils import handle_training_optuna |
9 | 13 | from seg_tgce.models.builders import build_pixel_model_from_hparams |
10 | 14 | from seg_tgce.models.ma_model import PixelVisualizationCallback |
11 | 15 |
|
|
# One simulated annotator (scorer) per configured noise level.
NUM_SCORERS = len(NOISE_LEVELS)
TRAIN_EPOCHS = 50  # epochs for the final training run
TUNER_EPOCHS = 1  # epochs per Optuna trial during hyperparameter search
TUNER_MAX_TRIALS = 1  # maximum number of Optuna trials
STUDY_NAME = "pets_pixel_tuning"  # Optuna study identifier
# Validation metric maximized by the tuner and monitored by training callbacks.
OBJECTIVE = "val_segmentation_output_dice_coefficient"
# NOTE(review): presumably the fraction of annotator-labeled samples — confirm
# against get_data_multiple_annotators.
LABELING_RATE = 1.0


# Hyperparameters used when tuning is disabled (no --use-tuner flag).
# NOTE(review): the key here is "initial_learning_rate" while the tuner branch
# suggests "learning_rate" — confirm the naming mismatch is intentional.
DEFAULT_HPARAMS = {
    "initial_learning_rate": 1e-3,
    "q": 0.7,
    "noise_tolerance": 0.5,
    "a": 0.2,
    "b": 0.7,
    "c": 1.0,
    "lambda_reg_weight": 0.1,
    "lambda_entropy_weight": 0.1,
    "lambda_sum_weight": 0.1,
}
21 | 39 |
|
22 | | -def build_model(hp: kt.HyperParameters) -> tf.keras.Model: |
23 | | - learning_rate = hp.Float( |
24 | | - "learning_rate", min_value=1e-5, max_value=1e-2, sampling="LOG" |
25 | | - ) |
26 | | - q = hp.Float("q", min_value=0.1, max_value=0.9, step=0.1) |
27 | | - noise_tolerance = hp.Float("noise_tolerance", min_value=0.1, max_value=0.9, step=0.1) |
28 | | - lambda_reg_weight = hp.Float( |
29 | | - "lambda_reg_weight", min_value=0.01, max_value=0.5, step=0.01 |
30 | | - ) |
31 | | - lambda_entropy_weight = hp.Float( |
32 | | - "lambda_entropy_weight", min_value=0.01, max_value=0.5, step=0.01 |
33 | | - ) |
34 | | - lambda_sum_weight = hp.Float( |
35 | | - "lambda_sum_weight", min_value=0.01, max_value=0.5, step=0.01 |
36 | | - ) |
| 40 | + |
def build_model_from_trial(trial: HpTunerTrial | None) -> Model:
    """Build the pixel model from an Optuna trial or from default hyperparameters.

    Args:
        trial: Active tuning trial whose ``suggest_float`` calls sample the
            hyperparameters, or ``None`` to fall back to ``DEFAULT_HPARAMS``.

    Returns:
        The compiled Keras model produced by ``build_pixel_model_from_hparams``.
    """
    if trial is None:
        # DEFAULT_HPARAMS stores the value under "initial_learning_rate",
        # but the builder expects the keyword "learning_rate" — remap it.
        hparams = {
            "learning_rate": DEFAULT_HPARAMS["initial_learning_rate"],
            **{
                key: value
                for key, value in DEFAULT_HPARAMS.items()
                if key != "initial_learning_rate"
            },
        }
    else:
        # Keep the suggestion order stable ("learning_rate", "q",
        # "noise_tolerance", "b", "a", "c", then the lambdas) so existing
        # Optuna studies remain compatible.
        hparams = {
            "learning_rate": trial.suggest_float(
                "learning_rate", 1e-5, 1e-2, log=True
            ),
            "q": trial.suggest_float("q", 0.1, 0.9, step=0.01),
            "noise_tolerance": trial.suggest_float(
                "noise_tolerance", 0.1, 0.9, step=0.01
            ),
            "b": trial.suggest_float("b", 0.1, 1.0, step=0.01),
            "a": trial.suggest_float("a", 0.1, 1.0, step=0.01),
            "c": trial.suggest_float("c", 0.1, 10.0, step=0.1),
            "lambda_reg_weight": trial.suggest_float(
                "lambda_reg_weight", 0.0, 10.0, step=0.1
            ),
            "lambda_entropy_weight": trial.suggest_float(
                "lambda_entropy_weight", 0.0, 10.0, step=0.1
            ),
            "lambda_sum_weight": trial.suggest_float(
                "lambda_sum_weight", 0.0, 10.0, step=0.1
            ),
        }

    # Single call site: both branches feed the same builder, so the shared
    # structural arguments can never drift between them.
    return build_pixel_model_from_hparams(
        **hparams,
        num_classes=NUM_CLASSES,
        target_shape=TARGET_SHAPE,
        n_scorers=NUM_SCORERS,
    )
49 | 74 |
|
50 | 75 |
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train pets pixel model with or without hyperparameter tuning"
    )
    parser.add_argument(
        "--use-tuner",
        action="store_true",
        # Fixed help text: tuning is performed with Optuna via
        # handle_training_optuna, not with Keras Tuner.
        help="Use Optuna for hyperparameter optimization",
    )
    args = parser.parse_args()

    # Build the multi-annotator Oxford-Pet datasets from the noise models.
    disturbance_models = fetch_models(NOISE_LEVELS)
    train, val, test = get_data_multiple_annotators(
        annotation_models=disturbance_models,
        target_shape=TARGET_SHAPE,
        batch_size=BATCH_SIZE,
        labeling_rate=LABELING_RATE,
    )

    # Tune (or build with defaults when --use-tuner is absent) on a small
    # cached subset so each trial stays cheap.
    model = handle_training_optuna(
        train.take(10).cache(),
        val.take(10).cache(),
        model_builder=build_model_from_trial,
        use_tuner=args.use_tuner,
        tuner_epochs=TUNER_EPOCHS,
        objective=OBJECTIVE,
        tuner_max_trials=TUNER_MAX_TRIALS,
        study_name=STUDY_NAME,
    )

    # Writes qualitative segmentation snapshots to disk during training.
    vis_callback = PixelVisualizationCallback(val, save_dir="vis/pets/pixel")

    # Halve the learning rate when the validation dice metric plateaus
    # (mode="max" because higher dice is better).
    lr_scheduler = ReduceLROnPlateau(
        monitor=OBJECTIVE,
        factor=0.5,
        patience=3,
        min_lr=1e-6,
        mode="max",
        verbose=1,
    )

    print("\nTraining final model...")
    history = model.fit(
        train,
        epochs=TRAIN_EPOCHS,
        validation_data=val.cache(),
        callbacks=[
            vis_callback,
            lr_scheduler,
            # Stop once validation dice stops improving and roll back to the
            # best-performing weights.
            EarlyStopping(
                monitor=OBJECTIVE,
                patience=5,
                mode="max",
                restore_best_weights=True,
            ),
        ],
    )

    plot_training_history(history, "Pets Pixel Model Training History")
    print_test_metrics(model, test, "Pets Pixel")
0 commit comments