Skip to content

Commit 84d3a9f

Browse files
committed
Added: Support for Dynamic Learning Rate
1 parent e1f05ce commit 84d3a9f

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

main.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import pandas as pd
1313
import plotly.graph_objs as go
1414

15-
# TODO: Add Support For Learning Rate Change
1615
# TODO: Add Support For Dynamic Polt.ly Charts
1716
# TODO: Add Support For Live Training Graphs (on_train_batch_end) without slowing down the Training Process
1817
# TODO: Add Supoort For EfficientNet - Fix Data Loader Input to be Un-Normalized Images
@@ -35,6 +34,8 @@
3534
}
3635

3736

37+
LEARNING_RATES = [0.00001, 0.0001, 0.001, 0.01, 0.1, 1]
38+
3839
BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256]
3940

4041
BACKBONES = [
@@ -156,6 +157,9 @@ def on_epoch_end(self, epoch, logs=None):
156157
# Select Optimizer
157158
selected_optimizer = st.selectbox("Training Optimizer", list(OPTIMIZERS.keys()))
158159

160+
# Select Learning Rate
161+
selected_learning_rate = st.select_slider("Learning Rate", LEARNING_RATES, 0.01)
162+
159163
# Select Batch Size
160164
selected_batch_size = st.select_slider("Train/Eval Batch Size", BATCH_SIZES, 16)
161165

@@ -197,11 +201,13 @@ def on_epoch_end(self, epoch, logs=None):
197201
batch_size=selected_batch_size, augment=False
198202
)
199203

204+
OPTIMIZERS[selected_optimizer].learning_rate.assign(selected_learning_rate)
205+
200206
classifier = ImageClassifier(
201207
backbone=selected_backbone,
202208
input_shape=input_shape,
203209
classes=train_data_loader.get_num_classes(),
204-
optimizer=selected_optimizer,
210+
optimizer=OPTIMIZERS[selected_optimizer],
205211
)
206212

207213
classifier.init_callbacks(

0 commit comments

Comments
 (0)