Skip to content

Commit 6492944

Browse files
committed
Merge branch 'pythonlessons:main' into main
2 parents dacabf5 + 0ec0bcc commit 6492944

File tree

7 files changed: +51 additions, -47 deletions

Tutorials/05_sound_to_text/configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from mltu.configs import BaseModelConfigs
55

6+
67
class ModelConfigs(BaseModelConfigs):
78
def __init__(self):
89
super().__init__()

mltu/configs.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
11
import os
22
import yaml
33

4+
45
class BaseModelConfigs:
56
def __init__(self):
67
self.model_path = None
78

89
def serialize(self):
9-
# get object attributes
10-
return self.__dict__
10+
class_attributes = {key: value
11+
for (key, value)
12+
in type(self).__dict__.items()
13+
if key not in ['__module__', '__init__', '__doc__', '__annotations__']}
14+
instance_attributes = self.__dict__
15+
16+
# first init with class attributes then apply instance attributes overwriting any existing duplicate attributes
17+
all_attributes = class_attributes.copy()
18+
all_attributes.update(instance_attributes)
19+
20+
return all_attributes
1121

12-
def save(self, name: str="configs.yaml"):
22+
def save(self, name: str = "configs.yaml"):
1323
if self.model_path is None:
1424
raise Exception("Model path is not specified")
1525

mltu/dataProvider.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,23 @@
99
from .transformers import Transformer
1010

1111
import logging
12-
logging.basicConfig(format="%(asctime)s %(levelname)s %(name)s: %(message)s")
1312

1413

1514
class DataProvider:
1615
def __init__(
17-
self,
18-
dataset: typing.Union[str, list, pd.DataFrame],
19-
data_preprocessors: typing.List[typing.Callable] = None,
20-
batch_size: int = 4,
21-
shuffle: bool = True,
22-
initial_epoch: int = 1,
23-
augmentors: typing.List[Augmentor] = None,
24-
transformers: typing.List[Transformer] = None,
25-
skip_validation: bool = True,
26-
limit: int = None,
27-
use_cache: bool = False,
28-
log_level: int = logging.INFO,
29-
) -> None:
16+
self,
17+
dataset: typing.Union[str, list, pd.DataFrame],
18+
data_preprocessors: typing.List[typing.Callable] = None,
19+
batch_size: int = 4,
20+
shuffle: bool = True,
21+
initial_epoch: int = 1,
22+
augmentors: typing.List[Augmentor] = None,
23+
transformers: typing.List[Transformer] = None,
24+
skip_validation: bool = True,
25+
limit: int = None,
26+
use_cache: bool = False,
27+
log_level: int = logging.INFO,
28+
) -> None:
3029
""" Standardised object for providing data to a model while training.
3130
3231
Attributes:
@@ -61,7 +60,7 @@ def __init__(
6160

6261
# Validate dataset
6362
if not skip_validation:
64-
self._dataset = self.validate(dataset, skip_validation, limit)
63+
self._dataset = self.validate(dataset)
6564
else:
6665
self.logger.info("Skipping Dataset validation...")
6766

@@ -91,8 +90,6 @@ def augmentors(self, augmentors: typing.List[Augmentor]):
9190
else:
9291
self.logger.warning(f"Augmentor {augmentor} is not an instance of Augmentor.")
9392

94-
return self._augmentors
95-
9693
@property
9794
def transformers(self) -> typing.List[Transformer]:
9895
""" Return transformers """
@@ -111,8 +108,6 @@ def transformers(self, transformers: typing.List[Transformer]):
111108
else:
112109
self.logger.warning(f"Transformer {transformer} is not an instance of Transformer.")
113110

114-
return self._transformers
115-
116111
@property
117112
def epoch(self) -> int:
118113
""" Return Current Epoch"""
@@ -131,28 +126,28 @@ def on_epoch_end(self):
131126

132127
# Remove any samples that were marked for removal
133128
for remove in self._on_epoch_end_remove:
134-
self.logger.warn(f"Removing {remove} from dataset.")
129+
self.logger.warning(f"Removing {remove} from dataset.")
135130
self._dataset.remove(remove)
136131
self._on_epoch_end_remove = []
137132

138-
def validate_list_dataset(self, dataset: list, skip_validation: bool = False) -> list:
133+
def validate_list_dataset(self, dataset: list) -> list:
139134
""" Validate a list dataset """
140135
validated_data = [data for data in tqdm(dataset, desc="Validating Dataset") if os.path.exists(data[0])]
141136
if not validated_data:
142137
raise FileNotFoundError("No valid data found in dataset.")
143138

144139
return validated_data
145140

146-
def validate(self, dataset: typing.Union[str, list, pd.DataFrame], skip_validation: bool) -> list:
141+
def validate(self, dataset: typing.Union[str, list, pd.DataFrame]) -> typing.Union[list, str]:
147142
""" Validate the dataset and return the dataset """
148143

149144
if isinstance(dataset, str):
150145
if os.path.exists(dataset):
151146
return dataset
152147
elif isinstance(dataset, list):
153-
return self.validate_list_dataset(dataset, skip_validation)
148+
return self.validate_list_dataset(dataset)
154149
elif isinstance(dataset, pd.DataFrame):
155-
return self.validate_list_dataset(dataset.values.tolist(), skip_validation)
150+
return self.validate_list_dataset(dataset.values.tolist())
156151
else:
157152
raise TypeError("Dataset must be a path, list or pandas dataframe.")
158153

@@ -176,7 +171,7 @@ def split(self, split: float = 0.9, shuffle: bool = True) -> typing.Tuple[typing
176171

177172
return train_data_provider, val_data_provider
178173

179-
def to_csv(self, path: str, index: bool=False) -> None:
174+
def to_csv(self, path: str, index: bool = False) -> None:
180175
""" Save the dataset to a csv file
181176
182177
Args:
@@ -230,8 +225,8 @@ def process_data(self, batch_data):
230225

231226
# Then augment, transform and postprocess the batch data
232227
for objects in [self._augmentors, self._transformers]:
233-
for object in objects:
234-
data, annotation = object(data, annotation)
228+
for _object in objects:
229+
data, annotation = _object(data, annotation)
235230

236231
# Convert to numpy array if not already
237232
if not isinstance(data, np.ndarray):
@@ -261,4 +256,4 @@ def __getitem__(self, index: int):
261256
batch_data.append(data)
262257
batch_annotations.append(annotation)
263258

264-
return np.array(batch_data), np.array(batch_annotations)
259+
return np.array(batch_data), np.array(batch_annotations)

mltu/preprocessors.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
from . import Image
1111

12-
logging.basicConfig(format="%(asctime)s %(levelname)s %(name)s: %(message)s")
13-
matplotlib.interactive(False)
1412

1513

1614
class ImageReader:
@@ -56,17 +54,20 @@ class WavReader:
5654
frame_step (int): Step size between frames in samples.
5755
fft_length (int): Number of FFT components.
5856
"""
57+
5958
def __init__(
60-
self,
61-
frame_length: int = 256,
62-
frame_step: int = 160,
63-
fft_length: int = 384,
64-
*args, **kwargs
65-
) -> None:
59+
self,
60+
frame_length: int = 256,
61+
frame_step: int = 160,
62+
fft_length: int = 384,
63+
*args, **kwargs
64+
) -> None:
6665
self.frame_length = frame_length
6766
self.frame_step = frame_step
6867
self.fft_length = fft_length
6968

69+
matplotlib.interactive(False)
70+
7071
@staticmethod
7172
def get_spectrogram(wav_path: str, frame_length: int, frame_step: int, fft_length: int) -> np.ndarray:
7273
"""Compute the spectrogram of a WAV file

mltu/tensorflow/__init__.py

Whitespace-only changes.

mltu/transformers.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55
from . import Image
66

77
import logging
8-
logging.basicConfig(format="%(asctime)s %(levelname)s %(name)s: %(message)s")
9-
logger = logging.getLogger(__name__)
10-
logger.setLevel(logging.INFO)
118

129

1310
class Transformer:

requirements.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@ tqdm
33
pandas
44
numpy
55
opencv-python
6-
Pillow==9.4.0
7-
onnxruntime # onnxruntime-gpu for GPU support
8-
librosa==0.9.2
6+
Pillow>=9.4.0
7+
onnxruntime>=1.15.0 # onnxruntime-gpu for GPU support
8+
librosa>=0.9.2
99
matplotlib
10-
onnx==1.14.0
11-
tf2onnx==1.14.0
10+
onnx>=1.14.0
11+
tf2onnx>=1.14.0

0 commit comments

Comments (0)