Skip to content

Commit 5b42eaa

Browse files
committed
styling
1 parent cddaace commit 5b42eaa

File tree

1 file changed

+18
-22
lines changed

1 file changed

+18
-22
lines changed

mltu/dataProvider.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,19 @@
1414

1515
class DataProvider:
1616
def __init__(
17-
self,
18-
dataset: typing.Union[str, list, pd.DataFrame],
19-
data_preprocessors: typing.List[typing.Callable] = None,
20-
batch_size: int = 4,
21-
shuffle: bool = True,
22-
initial_epoch: int = 1,
23-
augmentors: typing.List[Augmentor] = None,
24-
transformers: typing.List[Transformer] = None,
25-
skip_validation: bool = True,
26-
limit: int = None,
27-
use_cache: bool = False,
28-
log_level: int = logging.INFO,
29-
) -> None:
17+
self,
18+
dataset: typing.Union[str, list, pd.DataFrame],
19+
data_preprocessors: typing.List[typing.Callable] = None,
20+
batch_size: int = 4,
21+
shuffle: bool = True,
22+
initial_epoch: int = 1,
23+
augmentors: typing.List[Augmentor] = None,
24+
transformers: typing.List[Transformer] = None,
25+
skip_validation: bool = True,
26+
limit: int = None,
27+
use_cache: bool = False,
28+
log_level: int = logging.INFO,
29+
) -> None:
3030
""" Standardised object for providing data to a model while training.
3131
3232
Attributes:
@@ -91,8 +91,6 @@ def augmentors(self, augmentors: typing.List[Augmentor]):
9191
else:
9292
self.logger.warning(f"Augmentor {augmentor} is not an instance of Augmentor.")
9393

94-
return self._augmentors
95-
9694
@property
9795
def transformers(self) -> typing.List[Transformer]:
9896
""" Return transformers """
@@ -111,8 +109,6 @@ def transformers(self, transformers: typing.List[Transformer]):
111109
else:
112110
self.logger.warning(f"Transformer {transformer} is not an instance of Transformer.")
113111

114-
return self._transformers
115-
116112
@property
117113
def epoch(self) -> int:
118114
""" Return Current Epoch"""
@@ -131,7 +127,7 @@ def on_epoch_end(self):
131127

132128
# Remove any samples that were marked for removal
133129
for remove in self._on_epoch_end_remove:
134-
self.logger.warn(f"Removing {remove} from dataset.")
130+
self.logger.warning(f"Removing {remove} from dataset.")
135131
self._dataset.remove(remove)
136132
self._on_epoch_end_remove = []
137133

@@ -176,7 +172,7 @@ def split(self, split: float = 0.9, shuffle: bool = True) -> typing.Tuple[typing
176172

177173
return train_data_provider, val_data_provider
178174

179-
def to_csv(self, path: str, index: bool=False) -> None:
175+
def to_csv(self, path: str, index: bool = False) -> None:
180176
""" Save the dataset to a csv file
181177
182178
Args:
@@ -230,8 +226,8 @@ def process_data(self, batch_data):
230226

231227
# Then augment, transform and postprocess the batch data
232228
for objects in [self._augmentors, self._transformers]:
233-
for object in objects:
234-
data, annotation = object(data, annotation)
229+
for _object in objects:
230+
data, annotation = _object(data, annotation)
235231

236232
# Convert to numpy array if not already
237233
if not isinstance(data, np.ndarray):
@@ -261,4 +257,4 @@ def __getitem__(self, index: int):
261257
batch_data.append(data)
262258
batch_annotations.append(annotation)
263259

264-
return np.array(batch_data), np.array(batch_annotations)
260+
return np.array(batch_data), np.array(batch_annotations)

0 commit comments

Comments
 (0)