11import argparse
22
3+ from keras import Model
34from keras .callbacks import EarlyStopping , ReduceLROnPlateau
45
56from seg_tgce .data .crowd_seg .tfds_builder import (
89 get_processed_data ,
910)
1011from seg_tgce .experiments .plot_utils import plot_training_history , print_test_metrics
12+ from seg_tgce .experiments .types import HpTunerTrial
1113from seg_tgce .models .builders import build_features_model_from_hparams
1214from seg_tgce .models .ma_model import FeatureVisualizationCallback
1315
14- from ..utils import handle_training
16+ from ..utils import handle_training_optuna
1517
# Input resolution fed to the segmentation model (H, W).
TARGET_SHAPE = (256, 256)
BATCH_SIZE = 4
# Full training run length vs. the short per-trial budget used during tuning.
TRAIN_EPOCHS = 20
TUNER_EPOCHS = 1
TUNER_MAX_TRIALS = 10
# Optuna study identifier for the feature-model tuning runs.
STUDY_NAME = "histology_features_tuning"
# Metric optimized by the tuner and monitored by LR scheduling / early stopping.
OBJECTIVE = "val_segmentation_output_dice_coefficient"

# Fallback hyperparameters used when no tuner trial is supplied
# (see build_model_from_trial). Keys mirror the keyword arguments of
# build_features_model_from_hparams, except "initial_learning_rate"
# which maps to its learning_rate argument.
DEFAULT_HPARAMS = {
    "initial_learning_rate": 1e-3,
    "q": 0.5,
    "noise_tolerance": 0.5,
    "a": 0.5,
    "b": 0.5,
    "lambda_reg_weight": 0.1,
    "lambda_entropy_weight": 0.1,
    "lambda_sum_weight": 0.1,
}
3236
3337
def build_model_from_trial(trial: HpTunerTrial | None) -> Model:
    """Build the features segmentation model, from defaults or an Optuna trial.

    Args:
        trial: An Optuna-style trial used to sample hyperparameters. When
            ``None``, the model is built from ``DEFAULT_HPARAMS`` instead.

    Returns:
        A compiled Keras model produced by ``build_features_model_from_hparams``.
    """
    # Guard clause: no trial means a plain (non-tuning) run with defaults.
    if trial is None:
        return build_features_model_from_hparams(
            learning_rate=DEFAULT_HPARAMS["initial_learning_rate"],
            q=DEFAULT_HPARAMS["q"],
            noise_tolerance=DEFAULT_HPARAMS["noise_tolerance"],
            b=DEFAULT_HPARAMS["b"],
            a=DEFAULT_HPARAMS["a"],
            lambda_reg_weight=DEFAULT_HPARAMS["lambda_reg_weight"],
            lambda_entropy_weight=DEFAULT_HPARAMS["lambda_entropy_weight"],
            lambda_sum_weight=DEFAULT_HPARAMS["lambda_sum_weight"],
            num_classes=N_CLASSES,
            target_shape=TARGET_SHAPE,
            n_scorers=N_REAL_SCORERS,
        )

    # Tuning path: sample every hyperparameter from the trial. Learning rate is
    # searched on a log scale; the lambda_* loss weights share the same
    # [0.0, 10.0] range with 0.1 granularity.
    # NOTE(review): search ranges here are wider than DEFAULT_HPARAMS (e.g. "a"
    # up to 10.0 vs default 0.5) — presumably intentional; confirm against the
    # loss formulation.
    return build_features_model_from_hparams(
        learning_rate=trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True),
        q=trial.suggest_float("q", 0.1, 0.9, step=0.01),
        noise_tolerance=trial.suggest_float("noise_tolerance", 0.1, 0.9, step=0.01),
        a=trial.suggest_float("a", 0.1, 10.0, step=0.1),
        b=trial.suggest_float("b", 0.1, 0.99, step=0.01),
        lambda_reg_weight=trial.suggest_float("lambda_reg_weight", 0.0, 10.0, step=0.1),
        lambda_entropy_weight=trial.suggest_float(
            "lambda_entropy_weight", 0.0, 10.0, step=0.1
        ),
        lambda_sum_weight=trial.suggest_float("lambda_sum_weight", 0.0, 10.0, step=0.1),
        num_classes=N_CLASSES,
        target_shape=TARGET_SHAPE,
        n_scorers=N_REAL_SCORERS,
    )
@@ -88,21 +83,21 @@ def build_model(hp=None):
8883 image_size = TARGET_SHAPE , batch_size = BATCH_SIZE , use_augmentation = False
8984 )
9085
91- model = handle_training (
86+ model = handle_training_optuna (
9287 processed_train ,
9388 processed_validation ,
94- model_builder = build_model ,
89+ model_builder = build_model_from_trial ,
9590 use_tuner = args .use_tuner ,
9691 tuner_epochs = TUNER_EPOCHS ,
97- objective = "val_segmentation_output_dice_coefficient" ,
92+ objective = OBJECTIVE ,
9893 )
9994
10095 vis_callback = FeatureVisualizationCallback (
10196 processed_validation , save_dir = "vis/histology/features"
10297 )
10398
10499 lr_scheduler = ReduceLROnPlateau (
105- monitor = "val_segmentation_output_dice_coefficient" ,
100+ monitor = OBJECTIVE ,
106101 factor = 0.5 ,
107102 patience = 3 ,
108103 min_lr = 1e-6 ,
@@ -120,7 +115,7 @@ def build_model(hp=None):
120115 vis_callback ,
121116 lr_scheduler ,
122117 EarlyStopping (
123- monitor = "val_segmentation_output_dice_coefficient" ,
118+ monitor = OBJECTIVE ,
124119 patience = 5 ,
125120 mode = "max" ,
126121 restore_best_weights = True ,
0 commit comments