Skip to content

Commit 94ee052

Browse files
committed
Added Documentation for Data Loader
1 parent dc5e0e3 commit 94ee052

File tree

1 file changed

+199
-6
lines changed

1 file changed

+199
-6
lines changed

utils/data_loader.py

Lines changed: 199 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,55 @@
1414
import matplotlib.pyplot as plt
1515
import tensorflow as tf
1616
from glob import glob
17-
import numpy as np
1817

1918

2019
# TODO: Add Augmentations from Albumentations (https://github.com/albumentations-team/albumentations)
2120
# TODO: Add Tunable Augmentation Loading from a Config File
22-
# TODO: Add Check for num_min_samples
2321

2422

2523
class ImageClassificationDataLoader:
26-
__supported_im_formats = [".jpg", ".jpeg", ".png"]
24+
"""
25+
Data Loader for Image Classification
26+
27+
- Optimized Tf.Data implementation for maximum GPU usage
28+
- Automatically handle errors such as corrupted images
29+
- Built-in Dataset Verification
30+
- Built-in Checks for if dataset is of a supported format
31+
- Supports Auto Detect Sub-folders get class information
32+
- Auto Generate Class Label Map
33+
- Built in Image Augmentation
34+
- Dataset Batch Visualization (With and Without Augment)
35+
36+
37+
Raises:
38+
ValueError: Dataset Directory Path is Invalid
39+
ValueError: Raise when unsupported files are detected
40+
ValueError: Raise when Number of images are less than minimum specified
41+
AssertionError: get_batch_size method is called without initializing dataset_generator
42+
"""
43+
44+
__supported_im_formats = [".jpg", ".jpeg", ".png", ".bmp"]
2745

2846
def __init__(
29-
self, data_dir, image_dims=(224, 224), grayscale=False, num_min_samples=500
47+
self,
48+
data_dir: str,
49+
image_dims: tuple = (224, 224),
50+
grayscale: bool = False,
51+
num_min_samples: int = 500,
3052
) -> None:
53+
"""
54+
__init__
55+
56+
- Instance Variable Initialization
57+
- Dataset Verification
58+
- Listing all files in the given path
59+
60+
Args:
61+
data_dir (str): Path to the Dataset Directory
62+
image_dims (tuple, optional): Image Dimensions (width & height). Defaults to (224, 224).
63+
grayscale (bool, optional): If Grayscale, Select Single Channel, else RGB. Defaults to False.
64+
num_min_samples (int, optional): Minimum Number of Required Images per Class. Defaults to 500.
65+
"""
3166

3267
self.BATCH_SIZE = None
3368
self.LABELS = []
@@ -44,14 +79,26 @@ def __init__(
4479
)
4580

4681
def __dataset_verification(self) -> bool:
82+
"""
83+
__dataset_verification
84+
85+
Dataset Verification & Checks
86+
87+
Raises:
88+
ValueError: Dataset Directory Path is Invalid
89+
ValueError: Raise when unsupported files are detected
90+
ValueError: Raise when Number of images are less than minimum specified
91+
Returns:
92+
bool: True if all checks are passed
93+
"""
4794
# Check if the given directory is a valid directory path
4895
if not os.path.isdir(self.DATA_DIR):
4996
raise ValueError(f"Data Directory Path is Invalid")
5097

5198
# Assume the directory names as label names and get the label names
5299
self.LABELS = self.extract_labels()
53100

54-
# Check if all files in each folder is an image. If not, raise an alert
101+
# Check if all files in each folder is an image
55102
format_issues = {}
56103
quant_issues = {}
57104
for label in self.LABELS:
@@ -68,12 +115,13 @@ def __dataset_verification(self) -> bool:
68115

69116
quant_issues[label] = len(paths)
70117

71-
# Check if any of the dict values has any entry. If any entry is there, raise alert
118+
# Check if any of the classes have files that are not supported
72119
if any([len(format_issues[key]) for key in format_issues.keys()]):
73120
raise ValueError(
74121
f"Invalid File(s) Detected: {format_issues}\n\nSupported Formats: {self.__supported_im_formats}"
75122
)
76123

124+
# Check if any of the classes have number of images less than the minimum
77125
if any(
78126
[quant_issues[key] < self.NUM_MIN_SAMPLES for key in quant_issues.keys()]
79127
):
@@ -89,6 +137,14 @@ def __dataset_verification(self) -> bool:
89137
return True
90138

91139
def extract_labels(self) -> list:
140+
"""
141+
extract_labels
142+
143+
Extract the labels from the directory path (Folder Names)
144+
145+
Returns:
146+
list: List of Class Labels
147+
"""
92148
labels = [
93149
label
94150
for label in os.listdir(self.DATA_DIR)
@@ -97,10 +153,36 @@ def extract_labels(self) -> list:
97153
return labels
98154

99155
def get_encoded_labels(self, file_path) -> list:
156+
"""
157+
get_encoded_labels
158+
159+
Get the One-Hot Version of the Labels
160+
161+
Args:
162+
file_path (str): Complete path of the Image
163+
164+
Returns:
165+
list: One Hot Encoded Labelmap
166+
"""
100167
parts = tf.strings.split(file_path, os.path.sep)
101168
return parts[-2] == self.LABELS
102169

103170
def load_image(self, file_path) -> tuple:
171+
"""
172+
load_image
173+
174+
Read Image for Tf.Data
175+
- Read File & Load the Image
176+
- Decode the Image (jpg, png, bmp)
177+
- Normalize & Resize the Image
178+
- Match Input & Label and return
179+
180+
Args:
181+
file_path (str): Path of the Image File
182+
183+
Returns:
184+
tuple: Image and One-Hot Encoded Label
185+
"""
104186
label = self.get_encoded_labels(file_path)
105187
img = tf.io.read_file(file_path)
106188
img = tf.io.decode_image(
@@ -112,6 +194,24 @@ def load_image(self, file_path) -> tuple:
112194
return img, tf.cast(label, tf.float32)
113195

114196
def augment_batch(self, image, label) -> tuple:
197+
"""
198+
augment_batch
199+
200+
Image Augmentation for Training:
201+
- Random Contrast
202+
- Random Brightness
203+
- Random Hue (Color)
204+
- Random Saturation
205+
- Random Horizontal Flip
206+
- Random Reduction in Image Quality
207+
208+
Args:
209+
image (Tensor Image): Raw Image
210+
label (Tensor): One-Hot Encoded Label
211+
212+
Returns:
213+
tuple: Raw Image, One-Hot Encoded Label
214+
"""
115215
if tf.random.normal([1]) < 0:
116216
image = tf.image.random_contrast(image, 0.2, 0.9)
117217
if tf.random.normal([1]) < 0:
@@ -127,27 +227,98 @@ def augment_batch(self, image, label) -> tuple:
127227
return image, label
128228

129229
def get_supported_formats(self) -> str:
230+
"""
231+
get_supported_formats
232+
233+
Get the Supported Image Formats (String)
234+
235+
Returns:
236+
str: Supported Image Formats
237+
"""
130238
return f"Supported File Extensions: {', '.join(self.__supported_im_formats)}"
131239

132240
def get_supported_formats_list(self) -> list:
241+
"""
242+
get_supported_formats
243+
244+
Get the Supported Image Formats (List)
245+
246+
Returns:
247+
list: Supported Image Formats
248+
"""
133249
return self.__supported_im_formats
134250

135251
def get_labelmap(self) -> dict:
252+
"""
253+
get_labelmap
254+
255+
Get the Labelmap for the Classes
256+
Returns a List of Dictionaries containing the details
257+
Sample:
258+
[
259+
{
260+
"id": 0,
261+
"name": "male"
262+
},
263+
{
264+
"id": 1,
265+
"name": "female"
266+
}
267+
]
268+
269+
Returns:
270+
dict: Labelmap (ID and Label)
271+
"""
136272
labelmap = []
137273
for i, label in enumerate(self.LABELS):
138274
labelmap.append({"id": i, "name": label})
139275
return labelmap
140276

141277
def get_labels(self) -> list:
278+
"""
279+
get_labels
280+
281+
Get List of Labels (Class Names)
282+
283+
Returns:
284+
list: List of Labels (Class Names)
285+
"""
142286
return self.LABELS
143287

144288
def get_num_classes(self) -> int:
289+
"""
290+
get_num_classes
291+
292+
Get Total Number of Classes
293+
294+
Returns:
295+
int: Number of Classes (Labels)
296+
"""
145297
return len(self.LABELS)
146298

147299
def get_dataset_size(self) -> int:
300+
"""
301+
get_dataset_size
302+
303+
Get the Dataset Size (Number of Images)
304+
305+
Returns:
306+
int: Total Number of images in Dataset
307+
"""
148308
return len(list(self.dataset_files))
149309

150310
def get_num_steps(self) -> int:
311+
"""
312+
get_num_steps
313+
314+
Get the Number of Steps Required per Batch for Training
315+
316+
Raises:
317+
AssertionError: Dataset Generator needs to be Initialized First
318+
319+
Returns:
320+
int: Number of Steps Required for Training Per Batch
321+
"""
151322
if self.BATCH_SIZE is None:
152323
raise AssertionError(
153324
f"Batch Size is not Initialized. Call this method only after calling: {self.dataset_generator}"
@@ -156,6 +327,18 @@ def get_num_steps(self) -> int:
156327
return num_steps
157328

158329
def dataset_generator(self, batch_size=32, augment=False):
330+
"""
331+
dataset_generator
332+
333+
Create the Data Loader Pipeline and Return a Generator to Generate Datsets
334+
335+
Args:
336+
batch_size (int, optional): Batch Size. Defaults to 32.
337+
augment (bool, optional): Enable/Disable Augmentation. Defaults to False.
338+
339+
Returns:
340+
Tf.Data Generator: Dataset Generator
341+
"""
159342
self.BATCH_SIZE = batch_size
160343

161344
dataset = self.dataset_files.map(
@@ -174,6 +357,16 @@ def dataset_generator(self, batch_size=32, augment=False):
174357
return dataset
175358

176359
def visualize_batch(self, augment=True) -> None:
360+
"""
361+
visualize_batch
362+
363+
Dataset Sample Visualization
364+
- Supports Augmentation
365+
- Automatically Adjusts for Grayscale Images
366+
367+
Args:
368+
augment (bool, optional): Enable/Disable Augmentation. Defaults to True.
369+
"""
177370
if self.NUM_CHANNELS == 1:
178371
cmap = "gray"
179372
else:

0 commit comments

Comments
 (0)