99from .transformers import Transformer
1010
1111import logging
12- logging .basicConfig (format = "%(asctime)s %(levelname)s %(name)s: %(message)s" )
1312
1413
1514class DataProvider :
1615 def __init__ (
17- self ,
18- dataset : typing .Union [str , list , pd .DataFrame ],
19- data_preprocessors : typing .List [typing .Callable ] = None ,
20- batch_size : int = 4 ,
21- shuffle : bool = True ,
22- initial_epoch : int = 1 ,
23- augmentors : typing .List [Augmentor ] = None ,
24- transformers : typing .List [Transformer ] = None ,
25- skip_validation : bool = True ,
26- limit : int = None ,
27- use_cache : bool = False ,
28- log_level : int = logging .INFO ,
29- ) -> None :
16+ self ,
17+ dataset : typing .Union [str , list , pd .DataFrame ],
18+ data_preprocessors : typing .List [typing .Callable ] = None ,
19+ batch_size : int = 4 ,
20+ shuffle : bool = True ,
21+ initial_epoch : int = 1 ,
22+ augmentors : typing .List [Augmentor ] = None ,
23+ transformers : typing .List [Transformer ] = None ,
24+ skip_validation : bool = True ,
25+ limit : int = None ,
26+ use_cache : bool = False ,
27+ log_level : int = logging .INFO ,
28+ ) -> None :
3029 """ Standardised object for providing data to a model while training.
3130
3231 Attributes:
@@ -61,7 +60,7 @@ def __init__(
6160
6261 # Validate dataset
6362 if not skip_validation :
64- self ._dataset = self .validate (dataset , skip_validation , limit )
63+ self ._dataset = self .validate (dataset )
6564 else :
6665 self .logger .info ("Skipping Dataset validation..." )
6766
@@ -91,8 +90,6 @@ def augmentors(self, augmentors: typing.List[Augmentor]):
9190 else :
9291 self .logger .warning (f"Augmentor { augmentor } is not an instance of Augmentor." )
9392
94- return self ._augmentors
95-
9693 @property
9794 def transformers (self ) -> typing .List [Transformer ]:
9895 """ Return transformers """
@@ -111,8 +108,6 @@ def transformers(self, transformers: typing.List[Transformer]):
111108 else :
112109 self .logger .warning (f"Transformer { transformer } is not an instance of Transformer." )
113110
114- return self ._transformers
115-
116111 @property
117112 def epoch (self ) -> int :
118113 """ Return Current Epoch"""
@@ -131,28 +126,28 @@ def on_epoch_end(self):
131126
132127 # Remove any samples that were marked for removal
133128 for remove in self ._on_epoch_end_remove :
134- self .logger .warn (f"Removing { remove } from dataset." )
129+ self .logger .warning (f"Removing { remove } from dataset." )
135130 self ._dataset .remove (remove )
136131 self ._on_epoch_end_remove = []
137132
138- def validate_list_dataset (self , dataset : list , skip_validation : bool = False ) -> list :
133+ def validate_list_dataset (self , dataset : list ) -> list :
139134 """ Validate a list dataset """
140135 validated_data = [data for data in tqdm (dataset , desc = "Validating Dataset" ) if os .path .exists (data [0 ])]
141136 if not validated_data :
142137 raise FileNotFoundError ("No valid data found in dataset." )
143138
144139 return validated_data
145140
146- def validate (self , dataset : typing .Union [str , list , pd .DataFrame ], skip_validation : bool ) -> list :
141+ def validate (self , dataset : typing .Union [str , list , pd .DataFrame ]) -> typing . Union [ list , str ] :
147142 """ Validate the dataset and return the dataset """
148143
149144 if isinstance (dataset , str ):
150145 if os .path .exists (dataset ):
151146 return dataset
152147 elif isinstance (dataset , list ):
153- return self .validate_list_dataset (dataset , skip_validation )
148+ return self .validate_list_dataset (dataset )
154149 elif isinstance (dataset , pd .DataFrame ):
155- return self .validate_list_dataset (dataset .values .tolist (), skip_validation )
150+ return self .validate_list_dataset (dataset .values .tolist ())
156151 else :
157152 raise TypeError ("Dataset must be a path, list or pandas dataframe." )
158153
@@ -176,7 +171,7 @@ def split(self, split: float = 0.9, shuffle: bool = True) -> typing.Tuple[typing
176171
177172 return train_data_provider , val_data_provider
178173
179- def to_csv (self , path : str , index : bool = False ) -> None :
174+ def to_csv (self , path : str , index : bool = False ) -> None :
180175 """ Save the dataset to a csv file
181176
182177 Args:
@@ -230,8 +225,8 @@ def process_data(self, batch_data):
230225
231226 # Then augment, transform and postprocess the batch data
232227 for objects in [self ._augmentors , self ._transformers ]:
233- for object in objects :
234- data , annotation = object (data , annotation )
228+ for _object in objects :
229+ data , annotation = _object (data , annotation )
235230
236231 # Convert to numpy array if not already
237232 if not isinstance (data , np .ndarray ):
@@ -261,4 +256,4 @@ def __getitem__(self, index: int):
261256 batch_data .append (data )
262257 batch_annotations .append (annotation )
263258
264- return np .array (batch_data ), np .array (batch_annotations )
259+ return np .array (batch_data ), np .array (batch_annotations )
0 commit comments