Skip to content

Commit 1c17c0f

Browse files
committed
Added Plotly Graphs, Added Progress Bar and Status Updates
1 parent e3a8477 commit 1c17c0f

File tree

1 file changed

+75
-28
lines changed

1 file changed

+75
-28
lines changed

main.py

Lines changed: 75 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
import streamlit as st
1010
import numpy as np
1111
import pandas as pd
12-
import time
12+
import plotly.graph_objs as go
1313

1414
# TODO: Add Support For Learning Rate Change
1515
# TODO: Add Support For Dynamic Polt.ly Charts
16+
# TODO: Add Support For Live Training Graphs (on_train_batch_end) without slowing down the Training Process
1617

1718
OPTIMIZERS = {
1819
"SGD": tf.keras.optimizers.SGD(),
@@ -29,31 +30,77 @@
2930

3031

3132
class CustomCallback(tf.keras.callbacks.Callback):
32-
def __init__(self, total_steps):
33-
self.total_steps = total_steps
34-
self.loss_chart = st.line_chart(pd.DataFrame({"Loss": []}))
35-
self.acc_precision_recall_chart = st.line_chart()
36-
self.batch_progress = st.progress(0)
37-
38-
super().__init__()
33+
def __init__(self, num_steps):
34+
self.num_steps = num_steps
3935

40-
def __stream_to_graph(self, chart_obj, values):
41-
chart_obj.add_rows(np.array([values]))
36+
# Constants (TODO: Need to Optimize)
37+
self.train_losses = []
38+
self.val_losses = []
39+
self.train_accuracies = []
40+
self.val_accuracies = []
4241

43-
def __update_progress_bar(self, batch):
44-
current_progress = int(batch / self.total_steps * 100)
45-
self.batch_progress.progress(current_progress)
42+
# Progress
43+
self.epoch_text = st.empty()
44+
self.batch_progress = st.progress(0)
45+
self.status_text = st.empty()
46+
47+
# Charts
48+
self.loss_chart = st.empty()
49+
self.accuracy_chart = st.empty()
50+
51+
def update_graph(self, placeholder, items, title, xaxis, yaxis):
52+
fig = go.Figure()
53+
for key in items.keys():
54+
fig.add_trace(
55+
go.Scatter(
56+
y=items[key],
57+
mode="lines+markers",
58+
name=key,
59+
)
60+
)
61+
fig.update_layout(title=title, xaxis_title=xaxis, yaxis_title=yaxis)
62+
placeholder.write(fig)
4663

4764
def on_train_batch_end(self, batch, logs=None):
48-
49-
loss = logs["loss"]
50-
accuracy = logs["categorical_accuracy"]
51-
precision = logs["precision"]
52-
recall = logs["recall"]
53-
54-
self.__stream_to_graph(self.loss_chart, loss)
55-
self.__stream_to_graph(self.acc_precision_recall_chart, accuracy)
56-
self.__update_progress_bar(batch)
65+
self.batch_progress.progress(batch / self.num_steps)
66+
67+
def on_epoch_begin(self, epoch, logs=None):
68+
self.epoch_text.text(f"Epoch: {epoch + 1}")
69+
70+
def on_train_begin(self, logs=None):
71+
self.status_text.info(
72+
"Training Started! Live Graphs will be shown on the completion of the First Epoch"
73+
)
74+
75+
def on_train_end(self, logs=None):
76+
self.status_text.success("Training Completed!")
77+
st.balloons()
78+
79+
def on_epoch_end(self, epoch, logs=None):
80+
81+
self.train_losses.append(logs["loss"])
82+
self.val_losses.append(logs["val_loss"])
83+
self.train_accuracies.append(logs["categorical_accuracy"])
84+
self.val_accuracies.append(logs["val_categorical_accuracy"])
85+
86+
self.update_graph(
87+
self.loss_chart,
88+
{"Train Loss": self.train_losses, "Val Loss": self.val_losses},
89+
"Loss Curves",
90+
"Epochs",
91+
"Loss",
92+
)
93+
94+
self.update_graph(
95+
self.accuracy_chart,
96+
{
97+
"Train Accuracy": self.train_accuracies,
98+
"Val Accuracy": self.val_accuracies,
99+
},
100+
"Accuracy Curves",
101+
"Epochs",
102+
"Accuracy",
103+
)
57104

58105

59106
st.title("Zero Code Tensorflow Classifier Trainer")
@@ -64,11 +111,11 @@ def on_train_batch_end(self, batch, logs=None):
64111
# Enter Path for Train and Val Dataset
65112
train_data_dir = st.text_input(
66113
"Train Data Directory (Absolute Path)",
67-
"/home/ani/Documents/pycodes/Dataset/gender/Training/",
114+
"/home/ani/Documents/pycodes/Dataset/gender/Sample/",
68115
)
69116
val_data_dir = st.text_input(
70117
"Validation Data Directory (Absolute Path)",
71-
"/home/ani/Documents/pycodes/Dataset/gender/Validation/",
118+
"/home/ani/Documents/pycodes/Dataset/gender/Sample/",
72119
)
73120

74121
# Enter Path for Model Weights and Training Logs (Tensorboard)
@@ -86,7 +133,7 @@ def on_train_batch_end(self, batch, logs=None):
86133
selected_batch_size = st.select_slider("Train/Eval Batch Size", BATCH_SIZES, 16)
87134

88135
# Select Number of Epochs
89-
selected_epochs = st.number_input("Max Number of Epochs", 100)
136+
selected_epochs = st.number_input("Max Number of Epochs", 1, 300000, 100)
90137

91138
# Start Training Button
92139
start_training = st.button("Start Training")
@@ -96,14 +143,14 @@ def on_train_batch_end(self, batch, logs=None):
96143
data_dir=train_data_dir,
97144
image_dims=(224, 224),
98145
grayscale=False,
99-
num_min_samples=1000,
146+
num_min_samples=100,
100147
)
101148

102149
val_data_loader = ImageClassificationDataLoader(
103150
data_dir=val_data_dir,
104151
image_dims=(224, 224),
105152
grayscale=False,
106-
num_min_samples=1000,
153+
num_min_samples=100,
107154
)
108155

109156
train_generator = train_data_loader.dataset_generator(
@@ -114,7 +161,7 @@ def on_train_batch_end(self, batch, logs=None):
114161
)
115162

116163
classifier = ImageClassifier(
117-
backbone="ResNet50V2",
164+
backbone="EfficientNetB0",
118165
input_shape=(224, 224, 3),
119166
classes=train_data_loader.get_num_classes(),
120167
)

0 commit comments

Comments
 (0)