Skip to content

Commit e3a8477

Browse files
committed
Added Streamlit Dashboard - Basic
1 parent a2fd1a8 commit e3a8477

File tree

1 file changed

+140
-0
lines changed

1 file changed

+140
-0
lines changed

main.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import os
2+
3+
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
4+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
5+
6+
from utils.data_loader import ImageClassificationDataLoader
7+
from utils.model import ImageClassifier
8+
import tensorflow as tf
9+
import streamlit as st
10+
import numpy as np
11+
import pandas as pd
12+
import time
13+
14+
# TODO: Add Support For Learning Rate Change
15+
# TODO: Add Support For Dynamic Polt.ly Charts
16+
17+
OPTIMIZERS = {
18+
"SGD": tf.keras.optimizers.SGD(),
19+
"RMSprop": tf.keras.optimizers.RMSprop(),
20+
"Adam": tf.keras.optimizers.Adam(),
21+
"Adadelta": tf.keras.optimizers.Adadelta(),
22+
"Adagrad": tf.keras.optimizers.Adagrad(),
23+
"Adamax": tf.keras.optimizers.Adamax(),
24+
"Nadam": tf.keras.optimizers.Nadam(),
25+
"FTRL": tf.keras.optimizers.Ftrl(),
26+
}
27+
28+
BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256]
29+
30+
31+
class CustomCallback(tf.keras.callbacks.Callback):
32+
def __init__(self, total_steps):
33+
self.total_steps = total_steps
34+
self.loss_chart = st.line_chart(pd.DataFrame({"Loss": []}))
35+
self.acc_precision_recall_chart = st.line_chart()
36+
self.batch_progress = st.progress(0)
37+
38+
super().__init__()
39+
40+
def __stream_to_graph(self, chart_obj, values):
41+
chart_obj.add_rows(np.array([values]))
42+
43+
def __update_progress_bar(self, batch):
44+
current_progress = int(batch / self.total_steps * 100)
45+
self.batch_progress.progress(current_progress)
46+
47+
def on_train_batch_end(self, batch, logs=None):
48+
49+
loss = logs["loss"]
50+
accuracy = logs["categorical_accuracy"]
51+
precision = logs["precision"]
52+
recall = logs["recall"]
53+
54+
self.__stream_to_graph(self.loss_chart, loss)
55+
self.__stream_to_graph(self.acc_precision_recall_chart, accuracy)
56+
self.__update_progress_bar(batch)
57+
58+
59+
st.title("Zero Code Tensorflow Classifier Trainer")
60+
61+
with st.sidebar:
62+
st.header("Training Configuration")
63+
64+
# Enter Path for Train and Val Dataset
65+
train_data_dir = st.text_input(
66+
"Train Data Directory (Absolute Path)",
67+
"/home/ani/Documents/pycodes/Dataset/gender/Training/",
68+
)
69+
val_data_dir = st.text_input(
70+
"Validation Data Directory (Absolute Path)",
71+
"/home/ani/Documents/pycodes/Dataset/gender/Validation/",
72+
)
73+
74+
# Enter Path for Model Weights and Training Logs (Tensorboard)
75+
keras_weights_path = st.text_input(
76+
"Keras Weights File Path (Absolute Path)", "logs/models/weights.h5"
77+
)
78+
tensorboard_logs_path = st.text_input(
79+
"Tensorboard Logs Directory (Absolute Path)", "logs/tensorboard"
80+
)
81+
82+
# Select Optimizer
83+
selected_optimizer = st.selectbox("Training Optimizer", list(OPTIMIZERS.keys()))
84+
85+
# Select Batch Size
86+
selected_batch_size = st.select_slider("Train/Eval Batch Size", BATCH_SIZES, 16)
87+
88+
# Select Number of Epochs
89+
selected_epochs = st.number_input("Max Number of Epochs", 100)
90+
91+
# Start Training Button
92+
start_training = st.button("Start Training")
93+
94+
if start_training:
95+
train_data_loader = ImageClassificationDataLoader(
96+
data_dir=train_data_dir,
97+
image_dims=(224, 224),
98+
grayscale=False,
99+
num_min_samples=1000,
100+
)
101+
102+
val_data_loader = ImageClassificationDataLoader(
103+
data_dir=val_data_dir,
104+
image_dims=(224, 224),
105+
grayscale=False,
106+
num_min_samples=1000,
107+
)
108+
109+
train_generator = train_data_loader.dataset_generator(
110+
batch_size=selected_batch_size, augment=True
111+
)
112+
val_generator = val_data_loader.dataset_generator(
113+
batch_size=selected_batch_size, augment=False
114+
)
115+
116+
classifier = ImageClassifier(
117+
backbone="ResNet50V2",
118+
input_shape=(224, 224, 3),
119+
classes=train_data_loader.get_num_classes(),
120+
)
121+
122+
classifier.set_keras_weights_path(keras_weights_path)
123+
classifier.set_tensorboard_path(tensorboard_logs_path)
124+
125+
classifier.init_callbacks(
126+
keras_weights_path,
127+
tensorboard_logs_path,
128+
[CustomCallback(train_data_loader.get_num_steps())],
129+
)
130+
131+
classifier.set_optimizer(selected_optimizer)
132+
133+
classifier.train(
134+
train_generator,
135+
train_data_loader.get_num_steps(),
136+
val_generator,
137+
val_data_loader.get_num_steps(),
138+
epochs=selected_epochs,
139+
print_summary=False,
140+
)

0 commit comments

Comments
 (0)