99import streamlit as st
1010import numpy as np
1111import pandas as pd
12- import time
12+ import plotly . graph_objs as go
1313
1414# TODO: Add Support For Learning Rate Change
1515# TODO: Add Support For Dynamic Polt.ly Charts
16+ # TODO: Add Support For Live Training Graphs (on_train_batch_end) without slowing down the Training Process
1617
1718OPTIMIZERS = {
1819 "SGD" : tf .keras .optimizers .SGD (),
2930
3031
3132class CustomCallback (tf .keras .callbacks .Callback ):
32- def __init__ (self , total_steps ):
33- self .total_steps = total_steps
34- self .loss_chart = st .line_chart (pd .DataFrame ({"Loss" : []}))
35- self .acc_precision_recall_chart = st .line_chart ()
36- self .batch_progress = st .progress (0 )
37-
38- super ().__init__ ()
33+ def __init__ (self , num_steps ):
34+ self .num_steps = num_steps
3935
40- def __stream_to_graph (self , chart_obj , values ):
41- chart_obj .add_rows (np .array ([values ]))
36+ # Constants (TODO: Need to Optimize)
37+ self .train_losses = []
38+ self .val_losses = []
39+ self .train_accuracies = []
40+ self .val_accuracies = []
4241
43- def __update_progress_bar (self , batch ):
44- current_progress = int (batch / self .total_steps * 100 )
45- self .batch_progress .progress (current_progress )
42+ # Progress
43+ self .epoch_text = st .empty ()
44+ self .batch_progress = st .progress (0 )
45+ self .status_text = st .empty ()
46+
47+ # Charts
48+ self .loss_chart = st .empty ()
49+ self .accuracy_chart = st .empty ()
50+
51+ def update_graph (self , placeholder , items , title , xaxis , yaxis ):
52+ fig = go .Figure ()
53+ for key in items .keys ():
54+ fig .add_trace (
55+ go .Scatter (
56+ y = items [key ],
57+ mode = "lines+markers" ,
58+ name = key ,
59+ )
60+ )
61+ fig .update_layout (title = title , xaxis_title = xaxis , yaxis_title = yaxis )
62+ placeholder .write (fig )
4663
4764 def on_train_batch_end (self , batch , logs = None ):
48-
49- loss = logs ["loss" ]
50- accuracy = logs ["categorical_accuracy" ]
51- precision = logs ["precision" ]
52- recall = logs ["recall" ]
53-
54- self .__stream_to_graph (self .loss_chart , loss )
55- self .__stream_to_graph (self .acc_precision_recall_chart , accuracy )
56- self .__update_progress_bar (batch )
65+ self .batch_progress .progress (batch / self .num_steps )
66+
67+ def on_epoch_begin (self , epoch , logs = None ):
68+ self .epoch_text .text (f"Epoch: { epoch + 1 } " )
69+
70+ def on_train_begin (self , logs = None ):
71+ self .status_text .info (
72+ "Training Started! Live Graphs will be shown on the completion of the First Epoch"
73+ )
74+
75+ def on_train_end (self , logs = None ):
76+ self .status_text .success ("Training Completed!" )
77+ st .balloons ()
78+
79+ def on_epoch_end (self , epoch , logs = None ):
80+
81+ self .train_losses .append (logs ["loss" ])
82+ self .val_losses .append (logs ["val_loss" ])
83+ self .train_accuracies .append (logs ["categorical_accuracy" ])
84+ self .val_accuracies .append (logs ["val_categorical_accuracy" ])
85+
86+ self .update_graph (
87+ self .loss_chart ,
88+ {"Train Loss" : self .train_losses , "Val Loss" : self .val_losses },
89+ "Loss Curves" ,
90+ "Epochs" ,
91+ "Loss" ,
92+ )
93+
94+ self .update_graph (
95+ self .accuracy_chart ,
96+ {
97+ "Train Accuracy" : self .train_accuracies ,
98+ "Val Accuracy" : self .val_accuracies ,
99+ },
100+ "Accuracy Curves" ,
101+ "Epochs" ,
102+ "Accuracy" ,
103+ )
57104
58105
59106st .title ("Zero Code Tensorflow Classifier Trainer" )
@@ -64,11 +111,11 @@ def on_train_batch_end(self, batch, logs=None):
64111 # Enter Path for Train and Val Dataset
65112 train_data_dir = st .text_input (
66113 "Train Data Directory (Absolute Path)" ,
67- "/home/ani/Documents/pycodes/Dataset/gender/Training /" ,
114+ "/home/ani/Documents/pycodes/Dataset/gender/Sample /" ,
68115 )
69116 val_data_dir = st .text_input (
70117 "Validation Data Directory (Absolute Path)" ,
71- "/home/ani/Documents/pycodes/Dataset/gender/Validation /" ,
118+ "/home/ani/Documents/pycodes/Dataset/gender/Sample /" ,
72119 )
73120
74121 # Enter Path for Model Weights and Training Logs (Tensorboard)
@@ -86,7 +133,7 @@ def on_train_batch_end(self, batch, logs=None):
86133 selected_batch_size = st .select_slider ("Train/Eval Batch Size" , BATCH_SIZES , 16 )
87134
88135 # Select Number of Epochs
89- selected_epochs = st .number_input ("Max Number of Epochs" , 100 )
136+ selected_epochs = st .number_input ("Max Number of Epochs" , 1 , 300000 , 100 )
90137
91138 # Start Training Button
92139 start_training = st .button ("Start Training" )
@@ -96,14 +143,14 @@ def on_train_batch_end(self, batch, logs=None):
96143 data_dir = train_data_dir ,
97144 image_dims = (224 , 224 ),
98145 grayscale = False ,
99- num_min_samples = 1000 ,
146+ num_min_samples = 100 ,
100147 )
101148
102149 val_data_loader = ImageClassificationDataLoader (
103150 data_dir = val_data_dir ,
104151 image_dims = (224 , 224 ),
105152 grayscale = False ,
106- num_min_samples = 1000 ,
153+ num_min_samples = 100 ,
107154 )
108155
109156 train_generator = train_data_loader .dataset_generator (
@@ -114,7 +161,7 @@ def on_train_batch_end(self, batch, logs=None):
114161 )
115162
116163 classifier = ImageClassifier (
117- backbone = "ResNet50V2 " ,
164+ backbone = "EfficientNetB0 " ,
118165 input_shape = (224 , 224 , 3 ),
119166 classes = train_data_loader .get_num_classes (),
120167 )
0 commit comments