1+ __author__ = "Animikh Aich"
2+ __copyright__ = "Copyright 2021, Animikh Aich"
3+ __credits__ = ["Animikh Aich" ]
4+ __license__ = "MIT"
5+ __version__ = "0.1.0"
6+ __maintainer__ = "Animikh Aich"
7+ 8+ __status__ = "staging"
9+
110import os
211
312os .environ ["TF_FORCE_GPU_ALLOW_GROWTH" ] = "true"
1221import pandas as pd
1322import plotly .graph_objs as go
1423
15- # TODO: Add Support For Dynamic Polt.ly Charts
1624# TODO: Add Support For Live Training Graphs (on_train_batch_end) without slowing down the Training Process
1725# TODO: Add Supoort For EfficientNet - Fix Data Loader Input to be Un-Normalized Images
26+ # TODO: Add Supoort For Experiment and Logs Tracking and Comparison to Past Experiments
27+ # TODO: Add Support For Dataset Visualization
28+ # TODO: Add Support for Augmented Batch Visualization
29+ # TODO: Add Support for Augmentation Hyperparameter Customization (More Granular Control)
30+
1831
32+ # Constant Values that are Pre-defined for the dashboard to function
1933OPTIMIZERS = {
2034 "SGD" : tf .keras .optimizers .SGD (),
2135 "RMSprop" : tf .keras .optimizers .RMSprop (),
3347 "Mixed Precision (TPU - BF16) " : "mixed_bfloat16" ,
3448}
3549
36-
3750LEARNING_RATES = [0.00001 , 0.0001 , 0.001 , 0.01 , 0.1 , 1 ]
3851
3952BATCH_SIZES = [1 , 2 , 4 , 8 , 16 , 32 , 64 , 128 , 256 ]
6073]
6174
6275
76+ st .title ("Zero Code Tensorflow Classifier Trainer" )
77+
78+
6379class CustomCallback (tf .keras .callbacks .Callback ):
80+ """
81+ CustomCallback Keras Callback to Send Updates to Streamlit Dashboard
82+
83+ - Inherits from tf.keras.callbacks.Callback class
84+ - Sends Live Updates to the Dashboard
85+ - Allows Plotting Live Loss and Accuracy Curves
86+ - Allows Updating of Progress bar to track batch progress
87+ - Live plot only support Epoch Loss & Accuracy to improve training speed
88+ """
89+
6490 def __init__ (self , num_steps ):
91+ """
92+ __init__
93+
94+ Value Initializations
95+
96+ Args:
97+ num_steps (int): Total Number of Steps per Epoch
98+ """
6599 self .num_steps = num_steps
66100
67101 # Constants (TODO: Need to Optimize)
@@ -80,6 +114,19 @@ def __init__(self, num_steps):
80114 self .accuracy_chart = st .empty ()
81115
82116 def update_graph (self , placeholder , items , title , xaxis , yaxis ):
117+ """
118+ update_graph Function to Update the plot.ly graphs on Streamlit
119+
120+ - Updates the Graphs Whenever called with the passed values
121+ - Only supports Line plots for now
122+
123+ Args:
124+ placeholder (st.empty()): streamlit placeholder object
125+ items (dict): Containing Name of the plot and values
126+ title (str): Title of the Plot
127+ xaxis (str): X-Axis Label
128+ yaxis (str): Y-Axis Label
129+ """
83130 fig = go .Figure ()
84131 for key in items .keys ():
85132 fig .add_trace (
@@ -93,24 +140,68 @@ def update_graph(self, placeholder, items, title, xaxis, yaxis):
93140 placeholder .write (fig )
94141
95142 def on_train_batch_end (self , batch , logs = None ):
143+ """
144+ on_train_batch_end Update Progress Bar
145+
146+ At the end of each Training Batch, Update the progress bar
147+
148+ Args:
149+ batch (int): Current batch number
150+ logs (dict, optional): Training Metrics. Defaults to None.
151+ """
96152 self .batch_progress .progress (batch / self .num_steps )
97153
98154 def on_epoch_begin (self , epoch , logs = None ):
155+ """
156+ on_epoch_begin
157+
158+ Update the Dashboard on the Current Epoch Number
159+
160+ Args:
161+ batch (int): Current batch number
162+ logs (dict, optional): Training Metrics. Defaults to None.
163+ """
99164 self .epoch_text .text (f"Epoch: { epoch + 1 } " )
100165
101166 def on_train_begin (self , logs = None ):
167+ """
168+ on_train_begin
169+
170+ Status Update for the Dashboard with a message that training has started
171+
172+ Args:
173+ batch (int): Current batch number
174+ logs (dict, optional): Training Metrics. Defaults to None.
175+ """
102176 self .status_text .info (
103177 "Training Started! Live Graphs will be shown on the completion of the First Epoch."
104178 )
105179
106180 def on_train_end (self , logs = None ):
181+ """
182+ on_train_end
183+
184+ Status Update for the Dashboard with a message that training has ended
185+
186+ Args:
187+ batch (int): Current batch number
188+ logs (dict, optional): Training Metrics. Defaults to None.
189+ """
107190 self .status_text .success (
108191 f"Training Completed! Final Validation Accuracy: { logs ['val_categorical_accuracy' ]* 100 :.2f} %"
109192 )
110193 st .balloons ()
111194
112195 def on_epoch_end (self , epoch , logs = None ):
196+ """
197+ on_epoch_end
198+
199+ Update the Graphs with the train & val loss & accuracy curves (metrics)
113200
201+ Args:
202+ batch (int): Current batch number
203+ logs (dict, optional): Training Metrics. Defaults to None.
204+ """
114205 self .train_losses .append (logs ["loss" ])
115206 self .val_losses .append (logs ["val_loss" ])
116207 self .train_accuracies .append (logs ["categorical_accuracy" ])
@@ -136,8 +227,7 @@ def on_epoch_end(self, epoch, logs=None):
136227 )
137228
138229
139- st .title ("Zero Code Tensorflow Classifier Trainer" )
140-
230+ # Sidebar Configuration Parameters
141231with st .sidebar :
142232 st .header ("Training Configuration" )
143233
@@ -158,7 +248,7 @@ def on_epoch_end(self, epoch, logs=None):
158248 selected_optimizer = st .selectbox ("Training Optimizer" , list (OPTIMIZERS .keys ()))
159249
160250 # Select Learning Rate
161- selected_learning_rate = st .select_slider ("Learning Rate" , LEARNING_RATES , 0.01 )
251+ selected_learning_rate = st .select_slider ("Learning Rate" , LEARNING_RATES , 0.001 )
162252
163253 # Select Batch Size
164254 selected_batch_size = st .select_slider ("Train/Eval Batch Size" , BATCH_SIZES , 16 )
@@ -177,44 +267,54 @@ def on_epoch_end(self, epoch, logs=None):
177267 # Start Training Button
178268 start_training = st .button ("Start Training" )
179269
270+ # If the Button is pressed, start Training
180271if start_training :
272+ # Init the Input Shape for the Image
181273 input_shape = (selected_input_shape , selected_input_shape , 3 )
182274
275+ # Init Training Data Loader
183276 train_data_loader = ImageClassificationDataLoader (
184277 data_dir = train_data_dir ,
185278 image_dims = input_shape [:2 ],
186279 grayscale = False ,
187280 num_min_samples = 100 ,
188281 )
189282
283+ # Init Validation Data Loader
190284 val_data_loader = ImageClassificationDataLoader (
191285 data_dir = val_data_dir ,
192286 image_dims = input_shape [:2 ],
193287 grayscale = False ,
194288 num_min_samples = 100 ,
195289 )
196290
291+ # Get Training & Validation Dataset Generators
197292 train_generator = train_data_loader .dataset_generator (
198293 batch_size = selected_batch_size , augment = True
199294 )
200295 val_generator = val_data_loader .dataset_generator (
201296 batch_size = selected_batch_size , augment = False
202297 )
203298
299+ # Set the Learning Rate for the Selected Optimizer
204300 OPTIMIZERS [selected_optimizer ].learning_rate .assign (selected_learning_rate )
205301
302+ # Init the Classification Trainier
206303 classifier = ImageClassifier (
207304 backbone = selected_backbone ,
208305 input_shape = input_shape ,
209306 classes = train_data_loader .get_num_classes (),
210307 optimizer = OPTIMIZERS [selected_optimizer ],
211308 )
212309
310+ # Set the Callbacks to include the custom callback (to stream progress to dashboard)
213311 classifier .init_callbacks (
214312 [CustomCallback (train_data_loader .get_num_steps ())],
215313 )
314+ # Enable or Disable Mixed Precision Training
216315 classifier .set_precision (TRAINING_PRECISION [selected_precision ])
217316
317+ # Start Training
218318 classifier .train (
219319 train_generator ,
220320 train_data_loader .get_num_steps (),
0 commit comments