1414import matplotlib .pyplot as plt
1515import tensorflow as tf
1616from glob import glob
17- import numpy as np
1817
1918
2019# TODO: Add Augmentations from Albumentations (https://github.com/albumentations-team/albumentations)
2120# TODO: Add Tunable Augmentation Loading from a Config File
22- # TODO: Add Check for num_min_samples
2321
2422
class ImageClassificationDataLoader:
    """
    Data Loader for Image Classification

    - Optimized tf.data implementation for maximum GPU usage
    - Automatically handles errors such as corrupted images
    - Built-in dataset verification
    - Built-in checks that the dataset is of a supported format
    - Auto-detects sub-folders to derive class information
    - Auto-generates the class label map
    - Built-in image augmentation
    - Dataset batch visualization (with and without augmentation)


    Raises:
        ValueError: Dataset Directory Path is Invalid
        ValueError: Raised when unsupported files are detected
        ValueError: Raised when the number of images is less than the minimum specified
        AssertionError: get_num_steps is called before dataset_generator has set the batch size
    """

    # Image file extensions accepted during dataset verification.
    __supported_im_formats = [".jpg", ".jpeg", ".png", ".bmp"]
2745
    def __init__(
        self,
        data_dir: str,
        image_dims: tuple = (224, 224),
        grayscale: bool = False,
        num_min_samples: int = 500,
    ) -> None:
        """
        __init__

        - Instance Variable Initialization
        - Dataset Verification
        - Listing all files in the given path

        Args:
            data_dir (str): Path to the Dataset Directory
            image_dims (tuple, optional): Image Dimensions (width & height). Defaults to (224, 224).
            grayscale (bool, optional): If Grayscale, Select Single Channel, else RGB. Defaults to False.
            num_min_samples (int, optional): Minimum Number of Required Images per Class. Defaults to 500.
        """

        # None until dataset_generator() is called (get_num_steps guards on this).
        self.BATCH_SIZE = None
        # Populated during dataset verification from the dataset sub-folder names.
        self.LABELS = []
@@ -44,14 +79,26 @@ def __init__(
4479 )
4580
    def __dataset_verification(self) -> bool:
        """
        __dataset_verification

        Dataset Verification & Checks

        Raises:
            ValueError: Dataset Directory Path is Invalid
            ValueError: Raised when unsupported files are detected
            ValueError: Raised when the number of images is less than the minimum specified
        Returns:
            bool: True if all checks are passed
        """
        # Check if the given directory is a valid directory path
        if not os.path.isdir(self.DATA_DIR):
            raise ValueError(f"Data Directory Path is Invalid")

        # Assume the directory names as label names and get the label names
        self.LABELS = self.extract_labels()

        # Check if all files in each folder is an image
        format_issues = {}  # label -> files with unsupported extensions
        quant_issues = {}  # label -> number of files found for that class
        for label in self.LABELS:
@@ -68,12 +115,13 @@ def __dataset_verification(self) -> bool:
68115
69116 quant_issues [label ] = len (paths )
70117
71- # Check if any of the dict values has any entry. If any entry is there, raise alert
118+ # Check if any of the classes have files that are not supported
72119 if any ([len (format_issues [key ]) for key in format_issues .keys ()]):
73120 raise ValueError (
74121 f"Invalid File(s) Detected: { format_issues } \n \n Supported Formats: { self .__supported_im_formats } "
75122 )
76123
124+ # Check if any of the classes have number of images less than the minimum
77125 if any (
78126 [quant_issues [key ] < self .NUM_MIN_SAMPLES for key in quant_issues .keys ()]
79127 ):
@@ -89,6 +137,14 @@ def __dataset_verification(self) -> bool:
89137 return True
90138
91139 def extract_labels (self ) -> list :
140+ """
141+ extract_labels
142+
143+ Extract the labels from the directory path (Folder Names)
144+
145+ Returns:
146+ list: List of Class Labels
147+ """
92148 labels = [
93149 label
94150 for label in os .listdir (self .DATA_DIR )
@@ -97,10 +153,36 @@ def extract_labels(self) -> list:
97153 return labels
98154
99155 def get_encoded_labels (self , file_path ) -> list :
156+ """
157+ get_encoded_labels
158+
159+ Get the One-Hot Version of the Labels
160+
161+ Args:
162+ file_path (str): Complete path of the Image
163+
164+ Returns:
165+ list: One Hot Encoded Labelmap
166+ """
100167 parts = tf .strings .split (file_path , os .path .sep )
101168 return parts [- 2 ] == self .LABELS
102169
103170 def load_image (self , file_path ) -> tuple :
171+ """
172+ load_image
173+
174+ Read Image for Tf.Data
175+ - Read File & Load the Image
176+ - Decode the Image (jpg, png, bmp)
177+ - Normalize & Resize the Image
178+ - Match Input & Label and return
179+
180+ Args:
181+ file_path (str): Path of the Image File
182+
183+ Returns:
184+ tuple: Image and One-Hot Encoded Label
185+ """
104186 label = self .get_encoded_labels (file_path )
105187 img = tf .io .read_file (file_path )
106188 img = tf .io .decode_image (
@@ -112,6 +194,24 @@ def load_image(self, file_path) -> tuple:
112194 return img , tf .cast (label , tf .float32 )
113195
114196 def augment_batch (self , image , label ) -> tuple :
197+ """
198+ augment_batch
199+
200+ Image Augmentation for Training:
201+ - Random Contrast
202+ - Random Brightness
203+ - Random Hue (Color)
204+ - Random Saturation
205+ - Random Horizontal Flip
206+ - Random Reduction in Image Quality
207+
208+ Args:
209+ image (Tensor Image): Raw Image
210+ label (Tensor): One-Hot Encoded Label
211+
212+ Returns:
213+ tuple: Raw Image, One-Hot Encoded Label
214+ """
115215 if tf .random .normal ([1 ]) < 0 :
116216 image = tf .image .random_contrast (image , 0.2 , 0.9 )
117217 if tf .random .normal ([1 ]) < 0 :
@@ -127,27 +227,98 @@ def augment_batch(self, image, label) -> tuple:
127227 return image , label
128228
129229 def get_supported_formats (self ) -> str :
230+ """
231+ get_supported_formats
232+
233+ Get the Supported Image Formats (String)
234+
235+ Returns:
236+ str: Supported Image Formats
237+ """
130238 return f"Supported File Extensions: { ', ' .join (self .__supported_im_formats )} "
131239
132240 def get_supported_formats_list (self ) -> list :
241+ """
242+ get_supported_formats
243+
244+ Get the Supported Image Formats (List)
245+
246+ Returns:
247+ list: Supported Image Formats
248+ """
133249 return self .__supported_im_formats
134250
135251 def get_labelmap (self ) -> dict :
252+ """
253+ get_labelmap
254+
255+ Get the Labelmap for the Classes
256+ Returns a List of Dictionaries containing the details
257+ Sample:
258+ [
259+ {
260+ "id": 0,
261+ "name": "male"
262+ },
263+ {
264+ "id": 1,
265+ "name": "female"
266+ }
267+ ]
268+
269+ Returns:
270+ dict: Labelmap (ID and Label)
271+ """
136272 labelmap = []
137273 for i , label in enumerate (self .LABELS ):
138274 labelmap .append ({"id" : i , "name" : label })
139275 return labelmap
140276
141277 def get_labels (self ) -> list :
278+ """
279+ get_labels
280+
281+ Get List of Labels (Class Names)
282+
283+ Returns:
284+ list: List of Labels (Class Names)
285+ """
142286 return self .LABELS
143287
144288 def get_num_classes (self ) -> int :
289+ """
290+ get_num_classes
291+
292+ Get Total Number of Classes
293+
294+ Returns:
295+ int: Number of Classes (Labels)
296+ """
145297 return len (self .LABELS )
146298
147299 def get_dataset_size (self ) -> int :
300+ """
301+ get_dataset_size
302+
303+ Get the Dataset Size (Number of Images)
304+
305+ Returns:
306+ int: Total Number of images in Dataset
307+ """
148308 return len (list (self .dataset_files ))
149309
150310 def get_num_steps (self ) -> int :
311+ """
312+ get_num_steps
313+
314+ Get the Number of Steps Required per Batch for Training
315+
316+ Raises:
317+ AssertionError: Dataset Generator needs to be Initialized First
318+
319+ Returns:
320+ int: Number of Steps Required for Training Per Batch
321+ """
151322 if self .BATCH_SIZE is None :
152323 raise AssertionError (
153324 f"Batch Size is not Initialized. Call this method only after calling: { self .dataset_generator } "
@@ -156,6 +327,18 @@ def get_num_steps(self) -> int:
156327 return num_steps
157328
158329 def dataset_generator (self , batch_size = 32 , augment = False ):
330+ """
331+ dataset_generator
332+
333+ Create the Data Loader Pipeline and Return a Generator to Generate Datsets
334+
335+ Args:
336+ batch_size (int, optional): Batch Size. Defaults to 32.
337+ augment (bool, optional): Enable/Disable Augmentation. Defaults to False.
338+
339+ Returns:
340+ Tf.Data Generator: Dataset Generator
341+ """
159342 self .BATCH_SIZE = batch_size
160343
161344 dataset = self .dataset_files .map (
@@ -174,6 +357,16 @@ def dataset_generator(self, batch_size=32, augment=False):
174357 return dataset
175358
176359 def visualize_batch (self , augment = True ) -> None :
360+ """
361+ visualize_batch
362+
363+ Dataset Sample Visualization
364+ - Supports Augmentation
365+ - Automatically Adjusts for Grayscale Images
366+
367+ Args:
368+ augment (bool, optional): Enable/Disable Augmentation. Defaults to True.
369+ """
177370 if self .NUM_CHANNELS == 1 :
178371 cmap = "gray"
179372 else :
0 commit comments