Skip to content

Commit dc5e0e3

Browse files
committed
Added Docs to main file
1 parent 84d3a9f commit dc5e0e3

File tree

1 file changed

+105
-5
lines changed

1 file changed

+105
-5
lines changed

main.py

Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
__author__ = "Animikh Aich"
2+
__copyright__ = "Copyright 2021, Animikh Aich"
3+
__credits__ = ["Animikh Aich"]
4+
__license__ = "MIT"
5+
__version__ = "0.1.0"
6+
__maintainer__ = "Animikh Aich"
7+
__email__ = "[email protected]"
8+
__status__ = "staging"
9+
110
import os
211

312
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
@@ -12,10 +21,15 @@
1221
import pandas as pd
1322
import plotly.graph_objs as go
1423

15-
# TODO: Add Support For Dynamic Polt.ly Charts
1624
# TODO: Add Support For Live Training Graphs (on_train_batch_end) without slowing down the Training Process
1725
# TODO: Add Supoort For EfficientNet - Fix Data Loader Input to be Un-Normalized Images
26+
# TODO: Add Supoort For Experiment and Logs Tracking and Comparison to Past Experiments
27+
# TODO: Add Support For Dataset Visualization
28+
# TODO: Add Support for Augmented Batch Visualization
29+
# TODO: Add Support for Augmentation Hyperparameter Customization (More Granular Control)
30+
1831

32+
# Constant Values that are Pre-defined for the dashboard to function
1933
OPTIMIZERS = {
2034
"SGD": tf.keras.optimizers.SGD(),
2135
"RMSprop": tf.keras.optimizers.RMSprop(),
@@ -33,7 +47,6 @@
3347
"Mixed Precision (TPU - BF16) ": "mixed_bfloat16",
3448
}
3549

36-
3750
LEARNING_RATES = [0.00001, 0.0001, 0.001, 0.01, 0.1, 1]
3851

3952
BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256]
@@ -60,8 +73,29 @@
6073
]
6174

6275

76+
st.title("Zero Code Tensorflow Classifier Trainer")
77+
78+
6379
class CustomCallback(tf.keras.callbacks.Callback):
80+
"""
81+
CustomCallback Keras Callback to Send Updates to Streamlit Dashboard
82+
83+
- Inherits from tf.keras.callbacks.Callback class
84+
- Sends Live Updates to the Dashboard
85+
- Allows Plotting Live Loss and Accuracy Curves
86+
- Allows Updating of Progress bar to track batch progress
87+
- Live plot only support Epoch Loss & Accuracy to improve training speed
88+
"""
89+
6490
def __init__(self, num_steps):
91+
"""
92+
__init__
93+
94+
Value Initializations
95+
96+
Args:
97+
num_steps (int): Total Number of Steps per Epoch
98+
"""
6599
self.num_steps = num_steps
66100

67101
# Constants (TODO: Need to Optimize)
@@ -80,6 +114,19 @@ def __init__(self, num_steps):
80114
self.accuracy_chart = st.empty()
81115

82116
def update_graph(self, placeholder, items, title, xaxis, yaxis):
117+
"""
118+
update_graph Function to Update the plot.ly graphs on Streamlit
119+
120+
- Updates the Graphs Whenever called with the passed values
121+
- Only supports Line plots for now
122+
123+
Args:
124+
placeholder (st.empty()): streamlit placeholder object
125+
items (dict): Containing Name of the plot and values
126+
title (str): Title of the Plot
127+
xaxis (str): X-Axis Label
128+
yaxis (str): Y-Axis Label
129+
"""
83130
fig = go.Figure()
84131
for key in items.keys():
85132
fig.add_trace(
@@ -93,24 +140,68 @@ def update_graph(self, placeholder, items, title, xaxis, yaxis):
93140
placeholder.write(fig)
94141

95142
def on_train_batch_end(self, batch, logs=None):
143+
"""
144+
on_train_batch_end Update Progress Bar
145+
146+
At the end of each Training Batch, Update the progress bar
147+
148+
Args:
149+
batch (int): Current batch number
150+
logs (dict, optional): Training Metrics. Defaults to None.
151+
"""
96152
self.batch_progress.progress(batch / self.num_steps)
97153

98154
def on_epoch_begin(self, epoch, logs=None):
155+
"""
156+
on_epoch_begin
157+
158+
Update the Dashboard on the Current Epoch Number
159+
160+
Args:
161+
batch (int): Current batch number
162+
logs (dict, optional): Training Metrics. Defaults to None.
163+
"""
99164
self.epoch_text.text(f"Epoch: {epoch + 1}")
100165

101166
def on_train_begin(self, logs=None):
167+
"""
168+
on_train_begin
169+
170+
Status Update for the Dashboard with a message that training has started
171+
172+
Args:
173+
batch (int): Current batch number
174+
logs (dict, optional): Training Metrics. Defaults to None.
175+
"""
102176
self.status_text.info(
103177
"Training Started! Live Graphs will be shown on the completion of the First Epoch."
104178
)
105179

106180
def on_train_end(self, logs=None):
181+
"""
182+
on_train_end
183+
184+
Status Update for the Dashboard with a message that training has ended
185+
186+
Args:
187+
batch (int): Current batch number
188+
logs (dict, optional): Training Metrics. Defaults to None.
189+
"""
107190
self.status_text.success(
108191
f"Training Completed! Final Validation Accuracy: {logs['val_categorical_accuracy']*100:.2f}%"
109192
)
110193
st.balloons()
111194

112195
def on_epoch_end(self, epoch, logs=None):
196+
"""
197+
on_epoch_end
198+
199+
Update the Graphs with the train & val loss & accuracy curves (metrics)
113200
201+
Args:
202+
batch (int): Current batch number
203+
logs (dict, optional): Training Metrics. Defaults to None.
204+
"""
114205
self.train_losses.append(logs["loss"])
115206
self.val_losses.append(logs["val_loss"])
116207
self.train_accuracies.append(logs["categorical_accuracy"])
@@ -136,8 +227,7 @@ def on_epoch_end(self, epoch, logs=None):
136227
)
137228

138229

139-
st.title("Zero Code Tensorflow Classifier Trainer")
140-
230+
# Sidebar Configuration Parameters
141231
with st.sidebar:
142232
st.header("Training Configuration")
143233

@@ -158,7 +248,7 @@ def on_epoch_end(self, epoch, logs=None):
158248
selected_optimizer = st.selectbox("Training Optimizer", list(OPTIMIZERS.keys()))
159249

160250
# Select Learning Rate
161-
selected_learning_rate = st.select_slider("Learning Rate", LEARNING_RATES, 0.01)
251+
selected_learning_rate = st.select_slider("Learning Rate", LEARNING_RATES, 0.001)
162252

163253
# Select Batch Size
164254
selected_batch_size = st.select_slider("Train/Eval Batch Size", BATCH_SIZES, 16)
@@ -177,44 +267,54 @@ def on_epoch_end(self, epoch, logs=None):
177267
# Start Training Button
178268
start_training = st.button("Start Training")
179269

270+
# If the Button is pressed, start Training
180271
if start_training:
272+
# Init the Input Shape for the Image
181273
input_shape = (selected_input_shape, selected_input_shape, 3)
182274

275+
# Init Training Data Loader
183276
train_data_loader = ImageClassificationDataLoader(
184277
data_dir=train_data_dir,
185278
image_dims=input_shape[:2],
186279
grayscale=False,
187280
num_min_samples=100,
188281
)
189282

283+
# Init Validation Data Loader
190284
val_data_loader = ImageClassificationDataLoader(
191285
data_dir=val_data_dir,
192286
image_dims=input_shape[:2],
193287
grayscale=False,
194288
num_min_samples=100,
195289
)
196290

291+
# Get Training & Validation Dataset Generators
197292
train_generator = train_data_loader.dataset_generator(
198293
batch_size=selected_batch_size, augment=True
199294
)
200295
val_generator = val_data_loader.dataset_generator(
201296
batch_size=selected_batch_size, augment=False
202297
)
203298

299+
# Set the Learning Rate for the Selected Optimizer
204300
OPTIMIZERS[selected_optimizer].learning_rate.assign(selected_learning_rate)
205301

302+
# Init the Classification Trainier
206303
classifier = ImageClassifier(
207304
backbone=selected_backbone,
208305
input_shape=input_shape,
209306
classes=train_data_loader.get_num_classes(),
210307
optimizer=OPTIMIZERS[selected_optimizer],
211308
)
212309

310+
# Set the Callbacks to include the custom callback (to stream progress to dashboard)
213311
classifier.init_callbacks(
214312
[CustomCallback(train_data_loader.get_num_steps())],
215313
)
314+
# Enable or Disable Mixed Precision Training
216315
classifier.set_precision(TRAINING_PRECISION[selected_precision])
217316

317+
# Start Training
218318
classifier.train(
219319
train_generator,
220320
train_data_loader.get_num_steps(),

0 commit comments

Comments
 (0)