Skip to content

Commit bde422d

Browse files
committed
core/refac: #98 oxford pet prefetch optimization
- optimized dataset prefetching and properly displayed the best hyperparameters in Optuna tuning
1 parent 0191f9e commit bde422d

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

core/seg_tgce/data/oxford_pet/oxford_pet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def get_data_multiple_annotators(
7777
)
7878
for data, labeler_manager in (
7979
(
80-
train_dataset.cache().shuffle(1000).repeat(2).prefetch(tf.data.AUTOTUNE),
80+
train_dataset.shuffle(1000).prefetch(tf.data.AUTOTUNE),
8181
train_labeler_manager,
8282
),
8383
(val_dataset.prefetch(tf.data.AUTOTUNE), None),

core/seg_tgce/experiments/pets/scalar.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
from seg_tgce.models.builders import build_scalar_model_from_hparams
1414
from seg_tgce.models.ma_model import ScalarVisualizationCallback
1515

16-
TARGET_SHAPE = (64, 64)
16+
TARGET_SHAPE = (128, 128)
1717
BATCH_SIZE = 16
1818
NUM_CLASSES = 3
1919
NOISE_LEVELS = [-20.0, 10.0]
2020
NUM_SCORERS = len(NOISE_LEVELS)
2121
TRAIN_EPOCHS = 50
22-
TUNER_EPOCHS = 10
22+
TUNER_EPOCHS = 5
2323
LABELING_RATE = 0.5
24-
TUNER_MAX_TRIALS = 10
24+
TUNER_MAX_TRIALS = 3
2525
STUDY_NAME = "pets_scalar_tuning"
2626
OBJECTIVE = "val_segmentation_output_dice_coefficient"
2727
DEFAULT_HPARAMS = {
@@ -78,8 +78,8 @@ def build_model_from_trial(trial: HpTunerTrial | None) -> Model:
7878
)
7979

8080
model = handle_training_optuna(
81-
train,
82-
val,
81+
train.take(10).cache(),
82+
val.take(10).cache(),
8383
model_builder=build_model_from_trial,
8484
use_tuner=args.use_tuner,
8585
tuner_epochs=TUNER_EPOCHS,

core/seg_tgce/experiments/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def train_with_tuner(train_gen: tf.data.Dataset, val_gen: tf.data.Dataset) -> Mo
3636

3737
print("Starting hyperparameter search...")
3838
tuner.search(
39-
train_gen.take(10),
39+
train_gen,
4040
epochs=tuner_epochs,
4141
validation_data=val_gen,
4242
)
@@ -134,7 +134,7 @@ def _objective(trial: optuna.Trial) -> float:
134134

135135
best_hps = study.best_trial
136136
print("\nBest hyperparameters:")
137-
for param, value in best_hps.values.items():
137+
for param, value in best_hps.params.items():
138138
print(f"{param}: {value}")
139139

140140
create_importance_visualizations(study)

0 commit comments

Comments (0)