Skip to content

Commit a666a3e

Browse files
committed
Updated Model Architecture, Removed EfficientNet Support
1 parent 1c17c0f commit a666a3e

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

main.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# TODO: Add Support For Learning Rate Change
1515
# TODO: Add Support For Dynamic Polt.ly Charts
1616
# TODO: Add Support For Live Training Graphs (on_train_batch_end) without slowing down the Training Process
17+
# TODO: Add Supoort For EfficientNet - Fix Data Loader Input to be Un-Normalized Images
1718

1819
OPTIMIZERS = {
1920
"SGD": tf.keras.optimizers.SGD(),
@@ -28,6 +29,27 @@
2829

2930
BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256]
3031

32+
BACKBONES = [
33+
"MobileNet",
34+
"MobileNetV2",
35+
"ResNet50",
36+
"ResNet101",
37+
"ResNet152",
38+
"ResNet50V2",
39+
"ResNet101V2",
40+
"ResNet152V2",
41+
"VGG16",
42+
"VGG19",
43+
"Xception",
44+
"InceptionV3",
45+
"InceptionResNetV2",
46+
"DenseNet121",
47+
"DenseNet169",
48+
"DenseNet201",
49+
"NASNetMobile",
50+
"NASNetLarge",
51+
]
52+
3153

3254
class CustomCallback(tf.keras.callbacks.Callback):
3355
def __init__(self, num_steps):
@@ -69,11 +91,13 @@ def on_epoch_begin(self, epoch, logs=None):
6991

7092
def on_train_begin(self, logs=None):
7193
self.status_text.info(
72-
"Training Started! Live Graphs will be shown on the completion of the First Epoch"
94+
"Training Started! Live Graphs will be shown on the completion of the First Epoch."
7395
)
7496

7597
def on_train_end(self, logs=None):
76-
self.status_text.success("Training Completed!")
98+
self.status_text.success(
99+
f"Training Completed! Final Validation Accuracy: {logs['val_categorical_accuracy']*100:.2f}%"
100+
)
77101
st.balloons()
78102

79103
def on_epoch_end(self, epoch, logs=None):
@@ -126,14 +150,20 @@ def on_epoch_end(self, epoch, logs=None):
126150
"Tensorboard Logs Directory (Absolute Path)", "logs/tensorboard"
127151
)
128152

153+
# Select Backbone
154+
selected_backbone = st.selectbox("Select Backbone", BACKBONES)
155+
129156
# Select Optimizer
130157
selected_optimizer = st.selectbox("Training Optimizer", list(OPTIMIZERS.keys()))
131158

132159
# Select Batch Size
133160
selected_batch_size = st.select_slider("Train/Eval Batch Size", BATCH_SIZES, 16)
134161

135162
# Select Number of Epochs
136-
selected_epochs = st.number_input("Max Number of Epochs", 1, 300000, 100)
163+
selected_epochs = st.number_input("Max Number of Epochs", 1, 500, 100)
164+
165+
# Select Number of Epochs
166+
selected_input_shape = st.number_input("Input Image Shape", 64, 2000, 224)
137167

138168
# Start Training Button
139169
start_training = st.button("Start Training")
@@ -161,7 +191,7 @@ def on_epoch_end(self, epoch, logs=None):
161191
)
162192

163193
classifier = ImageClassifier(
164-
backbone="EfficientNetB0",
194+
backbone=selected_backbone,
165195
input_shape=(224, 224, 3),
166196
classes=train_data_loader.get_num_classes(),
167197
)
@@ -170,8 +200,6 @@ def on_epoch_end(self, epoch, logs=None):
170200
classifier.set_tensorboard_path(tensorboard_logs_path)
171201

172202
classifier.init_callbacks(
173-
keras_weights_path,
174-
tensorboard_logs_path,
175203
[CustomCallback(train_data_loader.get_num_steps())],
176204
)
177205

utils/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,10 @@ def get_default_callbacks(self, weights_path, tb_logs_path):
9999
callbacks = [early_stop_cb, model_ckpt_cb, reduce_lr_cb, tensorboard_cb]
100100
return callbacks
101101

102-
def init_callbacks(self, weights_path, tb_logs_path, custom_callbacks=[]):
103-
self.callbacks = self.get_default_callbacks(weights_path, tb_logs_path)
102+
def init_callbacks(self, custom_callbacks=[]):
103+
self.callbacks = self.get_default_callbacks(
104+
self.keras_weights_path, self.tensorboard_logs_path
105+
)
104106
self.callbacks.extend(custom_callbacks)
105107
return self.callbacks
106108

@@ -180,8 +182,6 @@ def init_network(self):
180182
backbone=self.backbone, input_shape=self.input_shape, classes=self.classes
181183
)
182184
x = base.output
183-
x = tf.keras.layers.Conv2D(64, (1, 1), activation="relu")(x)
184-
x = tf.keras.layers.Dropout(0.5)(x)
185185
x = tf.keras.layers.Flatten()(x)
186186
x = tf.keras.layers.Dense(units=128, activation="relu")(x)
187187
x = tf.keras.layers.Dropout(0.5)(x)

0 commit comments

Comments
 (0)