|
| 1 | +import os |
| 2 | + |
| 3 | +os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" |
| 4 | +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" |
| 5 | + |
| 6 | +from utils.data_loader import ImageClassificationDataLoader |
| 7 | +from utils.model import ImageClassifier |
| 8 | +import tensorflow as tf |
| 9 | +import streamlit as st |
| 10 | +import numpy as np |
| 11 | +import pandas as pd |
| 12 | +import time |
| 13 | + |
| 14 | +# TODO: Add Support For Learning Rate Change |
| 15 | +# TODO: Add Support For Dynamic Polt.ly Charts |
| 16 | + |
| 17 | +OPTIMIZERS = { |
| 18 | + "SGD": tf.keras.optimizers.SGD(), |
| 19 | + "RMSprop": tf.keras.optimizers.RMSprop(), |
| 20 | + "Adam": tf.keras.optimizers.Adam(), |
| 21 | + "Adadelta": tf.keras.optimizers.Adadelta(), |
| 22 | + "Adagrad": tf.keras.optimizers.Adagrad(), |
| 23 | + "Adamax": tf.keras.optimizers.Adamax(), |
| 24 | + "Nadam": tf.keras.optimizers.Nadam(), |
| 25 | + "FTRL": tf.keras.optimizers.Ftrl(), |
| 26 | +} |
| 27 | + |
| 28 | +BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256] |
| 29 | + |
| 30 | + |
| 31 | +class CustomCallback(tf.keras.callbacks.Callback): |
| 32 | + def __init__(self, total_steps): |
| 33 | + self.total_steps = total_steps |
| 34 | + self.loss_chart = st.line_chart(pd.DataFrame({"Loss": []})) |
| 35 | + self.acc_precision_recall_chart = st.line_chart() |
| 36 | + self.batch_progress = st.progress(0) |
| 37 | + |
| 38 | + super().__init__() |
| 39 | + |
| 40 | + def __stream_to_graph(self, chart_obj, values): |
| 41 | + chart_obj.add_rows(np.array([values])) |
| 42 | + |
| 43 | + def __update_progress_bar(self, batch): |
| 44 | + current_progress = int(batch / self.total_steps * 100) |
| 45 | + self.batch_progress.progress(current_progress) |
| 46 | + |
| 47 | + def on_train_batch_end(self, batch, logs=None): |
| 48 | + |
| 49 | + loss = logs["loss"] |
| 50 | + accuracy = logs["categorical_accuracy"] |
| 51 | + precision = logs["precision"] |
| 52 | + recall = logs["recall"] |
| 53 | + |
| 54 | + self.__stream_to_graph(self.loss_chart, loss) |
| 55 | + self.__stream_to_graph(self.acc_precision_recall_chart, accuracy) |
| 56 | + self.__update_progress_bar(batch) |
| 57 | + |
| 58 | + |
| 59 | +st.title("Zero Code Tensorflow Classifier Trainer") |
| 60 | + |
| 61 | +with st.sidebar: |
| 62 | + st.header("Training Configuration") |
| 63 | + |
| 64 | + # Enter Path for Train and Val Dataset |
| 65 | + train_data_dir = st.text_input( |
| 66 | + "Train Data Directory (Absolute Path)", |
| 67 | + "/home/ani/Documents/pycodes/Dataset/gender/Training/", |
| 68 | + ) |
| 69 | + val_data_dir = st.text_input( |
| 70 | + "Validation Data Directory (Absolute Path)", |
| 71 | + "/home/ani/Documents/pycodes/Dataset/gender/Validation/", |
| 72 | + ) |
| 73 | + |
| 74 | + # Enter Path for Model Weights and Training Logs (Tensorboard) |
| 75 | + keras_weights_path = st.text_input( |
| 76 | + "Keras Weights File Path (Absolute Path)", "logs/models/weights.h5" |
| 77 | + ) |
| 78 | + tensorboard_logs_path = st.text_input( |
| 79 | + "Tensorboard Logs Directory (Absolute Path)", "logs/tensorboard" |
| 80 | + ) |
| 81 | + |
| 82 | + # Select Optimizer |
| 83 | + selected_optimizer = st.selectbox("Training Optimizer", list(OPTIMIZERS.keys())) |
| 84 | + |
| 85 | + # Select Batch Size |
| 86 | + selected_batch_size = st.select_slider("Train/Eval Batch Size", BATCH_SIZES, 16) |
| 87 | + |
| 88 | + # Select Number of Epochs |
| 89 | + selected_epochs = st.number_input("Max Number of Epochs", 100) |
| 90 | + |
| 91 | + # Start Training Button |
| 92 | + start_training = st.button("Start Training") |
| 93 | + |
| 94 | +if start_training: |
| 95 | + train_data_loader = ImageClassificationDataLoader( |
| 96 | + data_dir=train_data_dir, |
| 97 | + image_dims=(224, 224), |
| 98 | + grayscale=False, |
| 99 | + num_min_samples=1000, |
| 100 | + ) |
| 101 | + |
| 102 | + val_data_loader = ImageClassificationDataLoader( |
| 103 | + data_dir=val_data_dir, |
| 104 | + image_dims=(224, 224), |
| 105 | + grayscale=False, |
| 106 | + num_min_samples=1000, |
| 107 | + ) |
| 108 | + |
| 109 | + train_generator = train_data_loader.dataset_generator( |
| 110 | + batch_size=selected_batch_size, augment=True |
| 111 | + ) |
| 112 | + val_generator = val_data_loader.dataset_generator( |
| 113 | + batch_size=selected_batch_size, augment=False |
| 114 | + ) |
| 115 | + |
| 116 | + classifier = ImageClassifier( |
| 117 | + backbone="ResNet50V2", |
| 118 | + input_shape=(224, 224, 3), |
| 119 | + classes=train_data_loader.get_num_classes(), |
| 120 | + ) |
| 121 | + |
| 122 | + classifier.set_keras_weights_path(keras_weights_path) |
| 123 | + classifier.set_tensorboard_path(tensorboard_logs_path) |
| 124 | + |
| 125 | + classifier.init_callbacks( |
| 126 | + keras_weights_path, |
| 127 | + tensorboard_logs_path, |
| 128 | + [CustomCallback(train_data_loader.get_num_steps())], |
| 129 | + ) |
| 130 | + |
| 131 | + classifier.set_optimizer(selected_optimizer) |
| 132 | + |
| 133 | + classifier.train( |
| 134 | + train_generator, |
| 135 | + train_data_loader.get_num_steps(), |
| 136 | + val_generator, |
| 137 | + val_data_loader.get_num_steps(), |
| 138 | + epochs=selected_epochs, |
| 139 | + print_summary=False, |
| 140 | + ) |
0 commit comments