|
12 | 12 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" |
13 | 13 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1" |
14 | 14 |
|
15 | | -from utils.data_loader import ImageClassificationDataLoader |
16 | | -from utils.model import ImageClassifier |
17 | | -from threading import Thread |
| 15 | +from core.data_loader import ImageClassificationDataLoader |
| 16 | +from core.model import ImageClassifier |
| 17 | +from utils.add_ons import CustomCallback |
18 | 18 | import tensorflow as tf |
19 | 19 | import streamlit as st |
20 | | -import numpy as np |
21 | | -import pandas as pd |
22 | | -import plotly.graph_objs as go |
23 | 20 |
|
24 | 21 | # TODO: Add Support For Live Training Graphs (on_train_batch_end) without slowing down the Training Process |
25 | 22 | # TODO: Add Supoort For EfficientNet - Fix Data Loader Input to be Un-Normalized Images |
|
104 | 101 | st.title("Zero Code Tensorflow Classifier Trainer") |
105 | 102 |
|
106 | 103 |
|
107 | | -class CustomCallback(tf.keras.callbacks.Callback): |
108 | | - """ |
109 | | - CustomCallback Keras Callback to Send Updates to Streamlit Dashboard |
110 | | -
|
111 | | - - Inherits from tf.keras.callbacks.Callback class |
112 | | - - Sends Live Updates to the Dashboard |
113 | | - - Allows Plotting Live Loss and Accuracy Curves |
114 | | - - Allows Updating of Progress bar to track batch progress |
115 | | - - Live plot only support Epoch Loss & Accuracy to improve training speed |
116 | | - """ |
117 | | - |
118 | | - def __init__(self, num_steps): |
119 | | - """ |
120 | | - __init__ |
121 | | -
|
122 | | - Value Initializations |
123 | | -
|
124 | | - Args: |
125 | | - num_steps (int): Total Number of Steps per Epoch |
126 | | - """ |
127 | | - self.num_steps = num_steps |
128 | | - |
129 | | - # Constants (TODO: Need to Optimize) |
130 | | - self.train_losses = [] |
131 | | - self.val_losses = [] |
132 | | - self.train_accuracies = [] |
133 | | - self.val_accuracies = [] |
134 | | - |
135 | | - # Progress |
136 | | - self.epoch_text = st.empty() |
137 | | - self.batch_progress = st.progress(0) |
138 | | - self.status_text = st.empty() |
139 | | - |
140 | | - # Charts |
141 | | - self.loss_chart = st.empty() |
142 | | - self.accuracy_chart = st.empty() |
143 | | - |
144 | | - def update_graph(self, placeholder, items, title, xaxis, yaxis): |
145 | | - """ |
146 | | - update_graph Function to Update the plot.ly graphs on Streamlit |
147 | | -
|
148 | | - - Updates the Graphs Whenever called with the passed values |
149 | | - - Only supports Line plots for now |
150 | | -
|
151 | | - Args: |
152 | | - placeholder (st.empty()): streamlit placeholder object |
153 | | - items (dict): Containing Name of the plot and values |
154 | | - title (str): Title of the Plot |
155 | | - xaxis (str): X-Axis Label |
156 | | - yaxis (str): Y-Axis Label |
157 | | - """ |
158 | | - fig = go.Figure() |
159 | | - for key in items.keys(): |
160 | | - fig.add_trace( |
161 | | - go.Scatter( |
162 | | - y=items[key], |
163 | | - mode="lines+markers", |
164 | | - name=key, |
165 | | - ) |
166 | | - ) |
167 | | - fig.update_layout(title=title, xaxis_title=xaxis, yaxis_title=yaxis) |
168 | | - placeholder.write(fig) |
169 | | - |
170 | | - def on_train_batch_end(self, batch, logs=None): |
171 | | - """ |
172 | | - on_train_batch_end Update Progress Bar |
173 | | -
|
174 | | - At the end of each Training Batch, Update the progress bar |
175 | | -
|
176 | | - Args: |
177 | | - batch (int): Current batch number |
178 | | - logs (dict, optional): Training Metrics. Defaults to None. |
179 | | - """ |
180 | | - self.batch_progress.progress(batch / self.num_steps) |
181 | | - |
182 | | - def on_epoch_begin(self, epoch, logs=None): |
183 | | - """ |
184 | | - on_epoch_begin |
185 | | -
|
186 | | - Update the Dashboard on the Current Epoch Number |
187 | | -
|
188 | | - Args: |
189 | | - batch (int): Current batch number |
190 | | - logs (dict, optional): Training Metrics. Defaults to None. |
191 | | - """ |
192 | | - self.epoch_text.text(f"Epoch: {epoch + 1}") |
193 | | - |
194 | | - def on_train_begin(self, logs=None): |
195 | | - """ |
196 | | - on_train_begin |
197 | | -
|
198 | | - Status Update for the Dashboard with a message that training has started |
199 | | -
|
200 | | - Args: |
201 | | - batch (int): Current batch number |
202 | | - logs (dict, optional): Training Metrics. Defaults to None. |
203 | | - """ |
204 | | - self.status_text.info( |
205 | | - "Training Started! Live Graphs will be shown on the completion of the First Epoch." |
206 | | - ) |
207 | | - |
208 | | - def on_train_end(self, logs=None): |
209 | | - """ |
210 | | - on_train_end |
211 | | -
|
212 | | - Status Update for the Dashboard with a message that training has ended |
213 | | -
|
214 | | - Args: |
215 | | - batch (int): Current batch number |
216 | | - logs (dict, optional): Training Metrics. Defaults to None. |
217 | | - """ |
218 | | - self.status_text.success( |
219 | | - f"Training Completed! Final Validation Accuracy: {logs['val_categorical_accuracy']*100:.2f}%" |
220 | | - ) |
221 | | - st.balloons() |
222 | | - |
223 | | - def on_epoch_end(self, epoch, logs=None): |
224 | | - """ |
225 | | - on_epoch_end |
226 | | -
|
227 | | - Update the Graphs with the train & val loss & accuracy curves (metrics) |
228 | | -
|
229 | | - Args: |
230 | | - batch (int): Current batch number |
231 | | - logs (dict, optional): Training Metrics. Defaults to None. |
232 | | - """ |
233 | | - self.train_losses.append(logs["loss"]) |
234 | | - self.val_losses.append(logs["val_loss"]) |
235 | | - self.train_accuracies.append(logs["categorical_accuracy"]) |
236 | | - self.val_accuracies.append(logs["val_categorical_accuracy"]) |
237 | | - |
238 | | - self.update_graph( |
239 | | - self.loss_chart, |
240 | | - {"Train Loss": self.train_losses, "Val Loss": self.val_losses}, |
241 | | - "Loss Curves", |
242 | | - "Epochs", |
243 | | - "Loss", |
244 | | - ) |
245 | | - |
246 | | - self.update_graph( |
247 | | - self.accuracy_chart, |
248 | | - { |
249 | | - "Train Accuracy": self.train_accuracies, |
250 | | - "Val Accuracy": self.val_accuracies, |
251 | | - }, |
252 | | - "Accuracy Curves", |
253 | | - "Epochs", |
254 | | - "Accuracy", |
255 | | - ) |
256 | | - |
257 | | - |
258 | 104 | # Sidebar Configuration Parameters |
259 | 105 | with st.sidebar: |
260 | 106 | st.header("Training Configuration") |
|
0 commit comments