Skip to content

Commit 323595a

Browse files
committed
Minor Re-shuffling
1 parent cd3286f commit 323595a

File tree

5 files changed

+184
-158
lines changed

5 files changed

+184
-158
lines changed

CHANGELOG.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,16 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7+
## Unreleased
78

8-
## Version 0.1.0 Beta - 2021-04-06
9+
### Modified
10+
- Re-shuffling:
11+
- `utils/data_loader.py` -> `core/data_loader.py`
12+
- `utils/model.py` -> `core/model.py`
13+
- Moved Custom Callbacks to new file: `utils/add_ons.py`
14+
15+
16+
## [Version 0.1.0 Beta] - 2021-04-06
917
### Added
1018
- Dockerfile
1119
- Launch Script (Dependency: [GNU Parallel](https://www.gnu.org/software/parallel/))
@@ -80,3 +88,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
8088
- Adamax
8189
- Nadam
8290
- FTRL
91+
92+
93+
[Version 0.1.0 Beta]: https://github.com/animikhaich/Zero-Code-TF-Classifier/releases/tag/v0.1-beta
File renamed without changes.
File renamed without changes.

main.py

Lines changed: 3 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,11 @@
1212
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
1313
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
1414

15-
from utils.data_loader import ImageClassificationDataLoader
16-
from utils.model import ImageClassifier
17-
from threading import Thread
15+
from core.data_loader import ImageClassificationDataLoader
16+
from core.model import ImageClassifier
17+
from utils.add_ons import CustomCallback
1818
import tensorflow as tf
1919
import streamlit as st
20-
import numpy as np
21-
import pandas as pd
22-
import plotly.graph_objs as go
2320

2421
# TODO: Add Support For Live Training Graphs (on_train_batch_end) without slowing down the Training Process
2522
# TODO: Add Supoort For EfficientNet - Fix Data Loader Input to be Un-Normalized Images
@@ -104,157 +101,6 @@
104101
st.title("Zero Code Tensorflow Classifier Trainer")
105102

106103

107-
class CustomCallback(tf.keras.callbacks.Callback):
108-
"""
109-
CustomCallback Keras Callback to Send Updates to Streamlit Dashboard
110-
111-
- Inherits from tf.keras.callbacks.Callback class
112-
- Sends Live Updates to the Dashboard
113-
- Allows Plotting Live Loss and Accuracy Curves
114-
- Allows Updating of Progress bar to track batch progress
115-
- Live plot only support Epoch Loss & Accuracy to improve training speed
116-
"""
117-
118-
def __init__(self, num_steps):
119-
"""
120-
__init__
121-
122-
Value Initializations
123-
124-
Args:
125-
num_steps (int): Total Number of Steps per Epoch
126-
"""
127-
self.num_steps = num_steps
128-
129-
# Constants (TODO: Need to Optimize)
130-
self.train_losses = []
131-
self.val_losses = []
132-
self.train_accuracies = []
133-
self.val_accuracies = []
134-
135-
# Progress
136-
self.epoch_text = st.empty()
137-
self.batch_progress = st.progress(0)
138-
self.status_text = st.empty()
139-
140-
# Charts
141-
self.loss_chart = st.empty()
142-
self.accuracy_chart = st.empty()
143-
144-
def update_graph(self, placeholder, items, title, xaxis, yaxis):
145-
"""
146-
update_graph Function to Update the plot.ly graphs on Streamlit
147-
148-
- Updates the Graphs Whenever called with the passed values
149-
- Only supports Line plots for now
150-
151-
Args:
152-
placeholder (st.empty()): streamlit placeholder object
153-
items (dict): Containing Name of the plot and values
154-
title (str): Title of the Plot
155-
xaxis (str): X-Axis Label
156-
yaxis (str): Y-Axis Label
157-
"""
158-
fig = go.Figure()
159-
for key in items.keys():
160-
fig.add_trace(
161-
go.Scatter(
162-
y=items[key],
163-
mode="lines+markers",
164-
name=key,
165-
)
166-
)
167-
fig.update_layout(title=title, xaxis_title=xaxis, yaxis_title=yaxis)
168-
placeholder.write(fig)
169-
170-
def on_train_batch_end(self, batch, logs=None):
171-
"""
172-
on_train_batch_end Update Progress Bar
173-
174-
At the end of each Training Batch, Update the progress bar
175-
176-
Args:
177-
batch (int): Current batch number
178-
logs (dict, optional): Training Metrics. Defaults to None.
179-
"""
180-
self.batch_progress.progress(batch / self.num_steps)
181-
182-
def on_epoch_begin(self, epoch, logs=None):
183-
"""
184-
on_epoch_begin
185-
186-
Update the Dashboard on the Current Epoch Number
187-
188-
Args:
189-
batch (int): Current batch number
190-
logs (dict, optional): Training Metrics. Defaults to None.
191-
"""
192-
self.epoch_text.text(f"Epoch: {epoch + 1}")
193-
194-
def on_train_begin(self, logs=None):
195-
"""
196-
on_train_begin
197-
198-
Status Update for the Dashboard with a message that training has started
199-
200-
Args:
201-
batch (int): Current batch number
202-
logs (dict, optional): Training Metrics. Defaults to None.
203-
"""
204-
self.status_text.info(
205-
"Training Started! Live Graphs will be shown on the completion of the First Epoch."
206-
)
207-
208-
def on_train_end(self, logs=None):
209-
"""
210-
on_train_end
211-
212-
Status Update for the Dashboard with a message that training has ended
213-
214-
Args:
215-
batch (int): Current batch number
216-
logs (dict, optional): Training Metrics. Defaults to None.
217-
"""
218-
self.status_text.success(
219-
f"Training Completed! Final Validation Accuracy: {logs['val_categorical_accuracy']*100:.2f}%"
220-
)
221-
st.balloons()
222-
223-
def on_epoch_end(self, epoch, logs=None):
224-
"""
225-
on_epoch_end
226-
227-
Update the Graphs with the train & val loss & accuracy curves (metrics)
228-
229-
Args:
230-
batch (int): Current batch number
231-
logs (dict, optional): Training Metrics. Defaults to None.
232-
"""
233-
self.train_losses.append(logs["loss"])
234-
self.val_losses.append(logs["val_loss"])
235-
self.train_accuracies.append(logs["categorical_accuracy"])
236-
self.val_accuracies.append(logs["val_categorical_accuracy"])
237-
238-
self.update_graph(
239-
self.loss_chart,
240-
{"Train Loss": self.train_losses, "Val Loss": self.val_losses},
241-
"Loss Curves",
242-
"Epochs",
243-
"Loss",
244-
)
245-
246-
self.update_graph(
247-
self.accuracy_chart,
248-
{
249-
"Train Accuracy": self.train_accuracies,
250-
"Val Accuracy": self.val_accuracies,
251-
},
252-
"Accuracy Curves",
253-
"Epochs",
254-
"Accuracy",
255-
)
256-
257-
258104
# Sidebar Configuration Parameters
259105
with st.sidebar:
260106
st.header("Training Configuration")

utils/add_ons.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
__author__ = "Animikh Aich"
2+
__copyright__ = "Copyright 2021, Animikh Aich"
3+
__credits__ = ["Animikh Aich"]
4+
__license__ = "MIT"
5+
__version__ = "0.1.0"
6+
__maintainer__ = "Animikh Aich"
7+
__email__ = "[email protected]"
8+
__status__ = "staging"
9+
10+
import os
11+
12+
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
13+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
14+
import tensorflow as tf
15+
import streamlit as st
16+
import numpy as np
17+
import pandas as pd
18+
import plotly.graph_objs as go
19+
20+
21+
class CustomCallback(tf.keras.callbacks.Callback):
22+
"""
23+
CustomCallback Keras Callback to Send Updates to Streamlit Dashboard
24+
25+
- Inherits from tf.keras.callbacks.Callback class
26+
- Sends Live Updates to the Dashboard
27+
- Allows Plotting Live Loss and Accuracy Curves
28+
- Allows Updating of Progress bar to track batch progress
29+
- Live plot only support Epoch Loss & Accuracy to improve training speed
30+
"""
31+
32+
def __init__(self, num_steps):
33+
"""
34+
__init__
35+
36+
Value Initializations
37+
38+
Args:
39+
num_steps (int): Total Number of Steps per Epoch
40+
"""
41+
self.num_steps = num_steps
42+
43+
# Constants (TODO: Need to Optimize)
44+
self.train_losses = []
45+
self.val_losses = []
46+
self.train_accuracies = []
47+
self.val_accuracies = []
48+
49+
# Progress
50+
self.epoch_text = st.empty()
51+
self.batch_progress = st.progress(0)
52+
self.status_text = st.empty()
53+
54+
# Charts
55+
self.loss_chart = st.empty()
56+
self.accuracy_chart = st.empty()
57+
58+
def update_graph(self, placeholder, items, title, xaxis, yaxis):
59+
"""
60+
update_graph Function to Update the plot.ly graphs on Streamlit
61+
62+
- Updates the Graphs Whenever called with the passed values
63+
- Only supports Line plots for now
64+
65+
Args:
66+
placeholder (st.empty()): streamlit placeholder object
67+
items (dict): Containing Name of the plot and values
68+
title (str): Title of the Plot
69+
xaxis (str): X-Axis Label
70+
yaxis (str): Y-Axis Label
71+
"""
72+
fig = go.Figure()
73+
for key in items.keys():
74+
fig.add_trace(
75+
go.Scatter(
76+
y=items[key],
77+
mode="lines+markers",
78+
name=key,
79+
)
80+
)
81+
fig.update_layout(title=title, xaxis_title=xaxis, yaxis_title=yaxis)
82+
placeholder.write(fig)
83+
84+
def on_train_batch_end(self, batch, logs=None):
85+
"""
86+
on_train_batch_end Update Progress Bar
87+
88+
At the end of each Training Batch, Update the progress bar
89+
90+
Args:
91+
batch (int): Current batch number
92+
logs (dict, optional): Training Metrics. Defaults to None.
93+
"""
94+
self.batch_progress.progress(batch / self.num_steps)
95+
96+
def on_epoch_begin(self, epoch, logs=None):
97+
"""
98+
on_epoch_begin
99+
100+
Update the Dashboard on the Current Epoch Number
101+
102+
Args:
103+
batch (int): Current batch number
104+
logs (dict, optional): Training Metrics. Defaults to None.
105+
"""
106+
self.epoch_text.text(f"Epoch: {epoch + 1}")
107+
108+
def on_train_begin(self, logs=None):
109+
"""
110+
on_train_begin
111+
112+
Status Update for the Dashboard with a message that training has started
113+
114+
Args:
115+
batch (int): Current batch number
116+
logs (dict, optional): Training Metrics. Defaults to None.
117+
"""
118+
self.status_text.info(
119+
"Training Started! Live Graphs will be shown on the completion of the First Epoch."
120+
)
121+
122+
def on_train_end(self, logs=None):
123+
"""
124+
on_train_end
125+
126+
Status Update for the Dashboard with a message that training has ended
127+
128+
Args:
129+
batch (int): Current batch number
130+
logs (dict, optional): Training Metrics. Defaults to None.
131+
"""
132+
self.status_text.success(
133+
f"Training Completed! Final Validation Accuracy: {logs['val_categorical_accuracy']*100:.2f}%"
134+
)
135+
st.balloons()
136+
137+
def on_epoch_end(self, epoch, logs=None):
138+
"""
139+
on_epoch_end
140+
141+
Update the Graphs with the train & val loss & accuracy curves (metrics)
142+
143+
Args:
144+
batch (int): Current batch number
145+
logs (dict, optional): Training Metrics. Defaults to None.
146+
"""
147+
self.train_losses.append(logs["loss"])
148+
self.val_losses.append(logs["val_loss"])
149+
self.train_accuracies.append(logs["categorical_accuracy"])
150+
self.val_accuracies.append(logs["val_categorical_accuracy"])
151+
152+
self.update_graph(
153+
self.loss_chart,
154+
{"Train Loss": self.train_losses, "Val Loss": self.val_losses},
155+
"Loss Curves",
156+
"Epochs",
157+
"Loss",
158+
)
159+
160+
self.update_graph(
161+
self.accuracy_chart,
162+
{
163+
"Train Accuracy": self.train_accuracies,
164+
"Val Accuracy": self.val_accuracies,
165+
},
166+
"Accuracy Curves",
167+
"Epochs",
168+
"Accuracy",
169+
)

0 commit comments

Comments
 (0)