
Commit dab3a49

wav2vec2 finetuning in PyTorch for Audio Recognition
1 parent 8dc94a3 commit dab3a49

File tree

12 files changed: +334 −125 lines


.vscode/launch.json

Lines changed: 2 additions & 1 deletion
```diff
@@ -10,7 +10,8 @@
             "request": "launch",
             "program": "${file}",
             "console": "integratedTerminal",
-            "justMyCode": false
+            "justMyCode": false,
+            "subProcess": true,
         }
     ]
}
```

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
```diff
@@ -1,3 +1,13 @@
+## [1.1.1] - 2022-09-26
+### Changed
+- Included `self._executor` as a generator in the `mltu.dataProvider.DataProvider` object, to enable modifying batch preprocessing without changing the original code
+- Introduced changes in `mltu.torch.dataProvider.py` to handle data in multiprocessing and multithreading modes, for faster preprocessing while training torch models
+- Modified the `mltu.transformers.AudioPadding` object to work with batches of raw audio data
+
+### Added
+- Created tutorial `10_wav2vec2_torch` (Audio to Text model) that shows how to train a wav2vec2 model with mltu
+
+
 ## [1.1.0] - 2022-08-28
 ### Changed
 - Changed `mltu.transformers.SpectrogramPadding` object, to pad spectrogram end with zeros instead of start
@@ -10,6 +20,7 @@
 - Created `mltu.tensorflow.transformer.callbacks` module, that contains `EncDecSplitCallback` callback, to split Transformer model into separate encoder and decoder models
 - Created `mltu.tensorflow.transformer.utils` module, that contains `MaskedLoss` loss and `MaskedAccuracy` metric, used for training Transformer models
 
+
 ## [1.0.15] - 2022-07-15
 ### Changed
 - Fixed bug in `mltu.dataProvider.DataProvider` to work with `batch_postprocessors`.
```
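
The `AudioPadding` change above is easiest to see in isolation. A minimal sketch of the batch mode, assuming the usual mltu transformer `__call__(data, label)` interface and the constructor arguments used in `train.py` below (the arrays and lengths are made up for illustration):

```python
import numpy as np
from mltu.transformers import AudioPadding

# Hypothetical batch: two variable-length raw audio clips (16 kHz, float32)
batch_audio = [
    np.random.randn(120000).astype(np.float32),
    np.random.randn(98000).astype(np.float32),
]
batch_labels = [np.zeros(256), np.zeros(256)]  # placeholder padded labels

# use_on_batch=True pads the whole batch at once (used as a batch_postprocessor)
padder = AudioPadding(max_audio_length=246000, padding_value=0, use_on_batch=True)
padded_audio, padded_labels = padder(batch_audio, batch_labels)

# Every clip now has the same length, so the batch stacks cleanly
print(np.array(padded_audio).shape)  # expected: (2, 246000)
```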

Tutorials/10_wav2vec2_torch/configs.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -10,15 +10,17 @@ def __init__(self):
             "Models/10_wav2vec2_torch",
             datetime.strftime(datetime.now(), "%Y%m%d%H%M"),
         )
-        self.batch_size = 6
+        self.batch_size = 8
         self.train_epochs = 60
         self.train_workers = 20
 
-        self.init_lr = 1.0e-7
+        self.init_lr = 1.0e-8
         self.lr_after_warmup = 1e-05
         self.final_lr = 5e-06
-        self.warmup_epochs = 5
+        self.warmup_epochs = 10
         self.decay_epochs = 40
+        self.weight_decay = 0.005
+        self.mixed_precision = True
 
         self.max_audio_length = 246000
         self.max_label_length = 256
```

Tutorials/10_wav2vec2_torch/test.py

Lines changed: 4 additions & 36 deletions
```diff
@@ -1,10 +1,7 @@
-import cv2
-import typing
 import numpy as np
 
 from mltu.inferenceModel import OnnxInferenceModel
 from mltu.utils.text_utils import ctc_decoder, get_cer, get_wer
-from mltu.preprocessors import AudioReader
 
 class Wav2vec2(OnnxInferenceModel):
     def __init__(self, *args, **kwargs):
@@ -21,48 +18,19 @@ def predict(self, audio: np.ndarray):
         return text
 
 if __name__ == "__main__":
+    import librosa
     import pandas as pd
     from tqdm import tqdm
-    import onnxruntime as ort
 
-    # model_path = "Models/11_wav2vec2_torch/202309131152/model.onnx"
-    # session = ort.InferenceSession(model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
-
-    # audio_len = 246000
-    # # Prepare input data (replace 'input' with the actual input name)
-    # input_data = {'input': np.random.randn(1, audio_len).astype(np.float32)}
-
-    # # Run inference
-    # output = session.run(None, input_data)
-
-    model = Wav2vec2(model_path="Models/11_wav2vec2_torch/202309141138/model.onnx")
+    model = Wav2vec2(model_path="Models/10_wav2vec2_torch/202309171434/model.onnx")
 
     # The list of multiple [audio_path, label] for validation
-    val_dataset = pd.read_csv("Models/11_wav2vec2_torch/202309141138/val.csv").values.tolist()
-
-    # model.vocab = [' ', "'", 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
-    audioReader = AudioReader(sample_rate=16000)
-
-    # dataset_path = "Datasets/LJSpeech-1.1"
-    # metadata_path = dataset_path + "/metadata.csv"
-    # wavs_path = dataset_path + "/wavs/"
-
-    # # Read metadata file and parse it
-    # metadata_df = pd.read_csv(metadata_path, sep="|", header=None, quoting=3)
-    # dataset = []
-    # # vocab = [' ', "'", 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
-    # for file_name, transcription, normalized_transcription in metadata_df.values.tolist():
-    #     path = f"Datasets/LJSpeech-1.1/wavs/{file_name}.wav"
-    #     new_label = "".join([l for l in normalized_transcription.lower() if l in model.vocab])
-    #     dataset.append([path, new_label])
-
+    val_dataset = pd.read_csv("Models/10_wav2vec2_torch/202309171434/val.csv").values.tolist()
 
     accum_cer, accum_wer = [], []
    pbar = tqdm(val_dataset)
    for vaw_path, label in pbar:
-        audio, label = audioReader(vaw_path, label)
+        audio, sr = librosa.load(vaw_path, sr=16000)
 
         prediction_text = model.predict(audio)
```
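The `predict` body sits outside the hunk above, so for orientation here is a hedged reconstruction of what it boils down to; the attribute names (`self.model`, `self.input_names`, `self.metadata["vocab"]`) are assumptions based on typical `OnnxInferenceModel` subclasses in mltu, not taken from this commit:

```python
# Hedged sketch of the predict() path (not shown in this diff):
# run the ONNX session on a batched float32 waveform, then CTC-decode.
def predict(self, audio: np.ndarray):
    # the exported model expects a (1, num_samples) float32 batch
    audio = np.expand_dims(audio, axis=0).astype(np.float32)

    # self.model / self.input_names are assumed OnnxInferenceModel attributes
    preds = self.model.run(None, {self.input_names[0]: audio})[0]

    # greedy CTC decode of the per-frame log-probabilities
    text = ctc_decoder(preds, self.metadata["vocab"])[0]
    return text
```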

Tutorials/10_wav2vec2_torch/train.py

Lines changed: 18 additions & 20 deletions
```diff
@@ -67,19 +67,21 @@ def download_and_unzip(url, extract_to="Datasets", chunk_size=1024*1024):
         LabelIndexer(vocab),
         LabelPadding(max_word_length=configs.max_label_length, padding_value=len(vocab)),
     ],
-    use_cache=True,
-    use_multiprocessing=False,
+    use_cache=False,
     batch_postprocessors=[
         AudioPadding(max_audio_length=configs.max_audio_length, padding_value=0, use_on_batch=True)
-    ]
+    ],
+    use_multiprocessing=True,
+    max_queue_size=10,
+    workers=64,
 )
-
 train_dataProvider, test_dataProvider = data_provider.split(split=0.9)
-train_dataProvider.augmentors = [
-    RandomAudioNoise(),
-    RandomAudioPitchShift(),
-    RandomAudioTimeStretch()
-]
+
+# train_dataProvider.augmentors = [
+#     RandomAudioNoise(),
+#     RandomAudioPitchShift(),
+#     RandomAudioTimeStretch()
+# ]
 
 vocab = sorted(vocab)
 configs.vocab = vocab
@@ -90,17 +92,11 @@ class CustomWav2Vec2Model(nn.Module):
     def __init__(self, hidden_states, dropout_rate=0.2, **kwargs):
         super(CustomWav2Vec2Model, self).__init__(**kwargs)
         pretrained_name = "facebook/wav2vec2-base-960h"
-        self.model = Wav2Vec2ForCTC.from_pretrained(pretrained_name).wav2vec2
-        # self.model.freeze_feature_encoder()
-        self.dropout = nn.Dropout(p=dropout_rate)
-        self.linear = nn.Linear(self.model.config.hidden_size, hidden_states)
+        self.model = Wav2Vec2ForCTC.from_pretrained(pretrained_name, vocab_size=hidden_states, ignore_mismatched_sizes=True)
+        self.model.freeze_feature_encoder()  # this part does not need to be fine-tuned
 
     def forward(self, inputs):
-        output = self.model(inputs, attention_mask=None).last_hidden_state
-        # Apply dropout
-        output = self.dropout(output)
-        # Apply linear layer
-        output = self.linear(output)
+        output = self.model(inputs, attention_mask=None).logits
         # Apply softmax
         output = F.log_softmax(output, -1)
         return output
@@ -118,6 +114,7 @@ def forward(self, inputs):
     decay_epochs=configs.decay_epochs,
     final_lr=configs.final_lr,
     initial_lr=configs.init_lr,
+    verbose=True,
 )
 tb_callback = TensorBoard(configs.model_path + "/logs")
 earlyStopping = EarlyStopping(monitor="val_CER", patience=16, mode="min", verbose=1)
@@ -133,12 +130,13 @@ def forward(self, inputs):
 # create model object that will handle training and testing of the network
 model = Model(
     custom_model,
-    loss = CTCLoss(blank=len(configs.vocab)),
-    optimizer = torch.optim.AdamW(custom_model.parameters(), lr=configs.init_lr, weight_decay=1e-5),
+    loss = CTCLoss(blank=len(configs.vocab), zero_infinity=True),
+    optimizer = torch.optim.AdamW(custom_model.parameters(), lr=configs.init_lr, weight_decay=configs.weight_decay),
     metrics=[
         CERMetric(configs.vocab),
         WERMetric(configs.vocab)
     ],
+    mixed_precision=configs.mixed_precision,
 )
 
 # Save training and validation datasets as csv files
```
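
The refactor above is worth spelling out: instead of taking the bare `.wav2vec2` encoder and stacking a custom `Dropout` + `Linear` head, the new code asks Hugging Face to build `Wav2Vec2ForCTC` with the target vocabulary size; `ignore_mismatched_sizes=True` lets the pretrained checkpoint load while the `lm_head` is re-initialized for the new vocab, and `freeze_feature_encoder()` freezes only the convolutional feature extractor. A self-contained sketch (the vocab size and input length are illustrative):

```python
import torch
import torch.nn.functional as F
from transformers import Wav2Vec2ForCTC

vocab_size = 29 + 1  # e.g. 29 characters + CTC blank (illustrative)

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",
    vocab_size=vocab_size,
    ignore_mismatched_sizes=True,  # lm_head is re-initialized for the new vocab
)
model.freeze_feature_encoder()  # only the CNN feature extractor is frozen

waveform = torch.randn(1, 246000)  # dummy batch of raw 16 kHz audio
log_probs = F.log_softmax(model(waveform).logits, dim=-1)
print(log_probs.shape)  # (1, num_frames, vocab_size) - per-frame inputs for CTCLoss
```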

Tutorials/10_wav2vec2_torch/train_tf.py

Lines changed: 12 additions & 16 deletions
```diff
@@ -24,12 +24,12 @@
 from configs import ModelConfigs
 
 configs = ModelConfigs()
-from transformers import TFWav2Vec2Model
+from transformers import TFWav2Vec2ForCTC
 from mltu.preprocessors import AudioReader
 
 
-train_dataset = pd.read_csv("Models/11_wav2vec2_torch/202309141138/train.csv").values.tolist()
-validation_dataset = pd.read_csv("Models/11_wav2vec2_torch/202309141138/val.csv").values.tolist()
+train_dataset = pd.read_csv("Models/10_wav2vec2_torch/202309171434/train.csv").values.tolist()
+validation_dataset = pd.read_csv("Models/10_wav2vec2_torch/202309171434/val.csv").values.tolist()
 
 # Create a data provider for the dataset
 train_dataProvider = DataProvider(
@@ -71,28 +71,24 @@
     use_cache=True,
 )
 
-class TFWav2Vec2ForCTC(layers.Layer):
-    def __init__(self, output_dim, dropout_rate=0.2, **kwargs):
+class CustomWav2Vec2Model(layers.Layer):
+    def __init__(self, output_dim, **kwargs):
         super().__init__(**kwargs)
 
-        self.wav2vec2 = TFWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
-        self.dropout = layers.Dropout(dropout_rate)
-        self.final_layer = layers.Dense(output_dim, activation="softmax")
+        pretrained_name = "facebook/wav2vec2-base-960h"
+        self.model = TFWav2Vec2ForCTC.from_pretrained(pretrained_name, vocab_size=output_dim, ignore_mismatched_sizes=True)
+        self.model.freeze_feature_encoder()  # https://huggingface.co/blog/fine-tune-wav2vec2-english
 
     def __call__(self, inputs):
-        outputs = self.wav2vec2(inputs)
+        outputs = self.model(inputs)
 
-        hidden_states = outputs.last_hidden_state
-
-        dropout = self.dropout(hidden_states)
-
-        final_state = self.final_layer(dropout)
+        final_state = tf.nn.softmax(outputs.logits, axis=-1)
 
         return final_state
 
 custom_model = tf.keras.Sequential([
     layers.Input(shape=(None,), name="input", dtype=tf.float32),
-    TFWav2Vec2ForCTC(len(configs.vocab)+1, dropout_rate=0.2)
+    CustomWav2Vec2Model(len(configs.vocab)+1)
 ])
 
 for data in train_dataProvider:
@@ -105,7 +101,7 @@ def __call__(self, inputs):
 
 # Compile the model and print summary
 custom_model.compile(
-    optimizer=tf.keras.optimizers.AdamW(learning_rate=configs.init_lr, weight_decay=1e-5),
+    optimizer=tf.keras.optimizers.AdamW(learning_rate=configs.init_lr, weight_decay=configs.weight_decay),
     loss=CTCloss(),
     metrics=[
         CERMetric(vocabulary=configs.vocab),
```
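One detail to flag in this layer: it overrides `__call__` directly, which sidesteps Keras's own `Layer.__call__` bookkeeping (build, name scopes, masking). The conventional spelling, behaviour otherwise the same here, is to override `call` and let Keras wrap it — a sketch of the same layer in that form:

```python
import tensorflow as tf
from tensorflow.keras import layers
from transformers import TFWav2Vec2ForCTC

class CustomWav2Vec2Model(layers.Layer):
    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.model = TFWav2Vec2ForCTC.from_pretrained(
            "facebook/wav2vec2-base-960h",
            vocab_size=output_dim,
            ignore_mismatched_sizes=True,
        )
        self.model.freeze_feature_encoder()

    def call(self, inputs):  # Keras invokes this through Layer.__call__
        # per-frame probabilities over the vocabulary, as expected by CTCloss
        return tf.nn.softmax(self.model(inputs).logits, axis=-1)
```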

mltu/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,4 +1,4 @@
-__version__ = "1.1.0"
+__version__ = "1.1.1"
 
 from .annotations.images import Image
 from .annotations.images import CVImage
```

mltu/dataProvider.py

Lines changed: 23 additions & 9 deletions
```diff
@@ -205,10 +205,21 @@ def get_batch_annotations(self, index: int) -> typing.List:
 
         return batch_annotations
 
+    def start_executor(self) -> None:
+        """ Start the executor to process data """
+        def executor(batch_data):
+            for data in batch_data:
+                yield self.process_data(data)
+
+        if not hasattr(self, "_executor"):
+            self._executor = executor
+
     def __iter__(self):
         """ Create a generator that iterates over the Sequence."""
-        for item in (self[i] for i in range(len(self))):
-            yield item
+        self.start_executor()
+        for index in range(len(self)):
+            results = self[index]
+            yield results
 
     def process_data(self, batch_data):
         """ Process a batch of data """
@@ -250,19 +261,22 @@ def process_data(self, batch_data):
         return data, annotation
 
     def __getitem__(self, index: int):
-        """ Returns a batch of data by batch index"""
+        """ Returns a batch of processed data by index
+
+        Args:
+            index (int): index of batch
+
+        Returns:
+            tuple: batch of data and batch of annotations
+        """
         dataset_batch = self.get_batch_annotations(index)
 
         # First read and preprocess the batch data
         batch_data, batch_annotations = [], []
-        for index, batch in enumerate(dataset_batch):
-
-            data, annotation = self.process_data(batch)
-
+        for data, annotation in self._executor(dataset_batch):
             if data is None or annotation is None:
                 self.logger.warning("Data or annotation is None, skipping.")
                 continue
-
             batch_data.append(data)
             batch_annotations.append(annotation)
 
@@ -272,4 +286,4 @@ def __getitem__(self, index: int):
 
         return batch_data, batch_annotations
 
-        return np.array(batch_data), np.array(batch_annotations)
+        return np.array(batch_data), np.array(batch_annotations)
```
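
The practical payoff of routing `__getitem__` through `self._executor` (the changelog's "modify batch preprocessing without changing original code") is that a subclass can install its own generator before the default one is created. A minimal sketch of that extension point, using only the interface shown above (the subclass name and hook are hypothetical):

```python
from mltu.dataProvider import DataProvider

class CustomExecutorDataProvider(DataProvider):
    """Hypothetical subclass: replaces the default per-item generator."""

    def start_executor(self) -> None:
        def executor(batch_data):
            # custom batch-level hook could go here (filtering, timing, ...)
            for data in batch_data:
                yield self.process_data(data)

        # install unconditionally, before the hasattr() default would apply
        self._executor = executor
```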

mltu/torch/callbacks.py

Lines changed: 23 additions & 2 deletions
```diff
@@ -440,6 +440,7 @@ class WarmupCosineDecay(Callback):
         decay_epochs (int): Number of decay epochs
         initial_lr (float, optional): Initial learning rate. Defaults to 0.0.
         verbose (bool, optional): Whether to print learning rate. Defaults to False.
+        warmup_steps (int, optional): Number of warmup steps. Defaults to None.
     """
     def __init__(
         self,
@@ -448,6 +449,7 @@ def __init__(
         warmup_epochs: int,
         decay_epochs: int,
         initial_lr: float=0.0,
+        warmup_steps: int=None,
         verbose=False
     ) -> None:
         super(WarmupCosineDecay, self).__init__()
@@ -456,24 +458,43 @@ def __init__(
         self.warmup_epochs = warmup_epochs
         self.decay_epochs = decay_epochs
         self.initial_lr = initial_lr
+        self.warmup_steps = warmup_steps
         self.verbose = verbose
+        self.step = None
+
+        self.warmup_lrs = np.linspace(self.initial_lr, self.lr_after_warmup, self.warmup_epochs)
+        if warmup_steps:
+            self.step = 0
+            self.warmup_epochs = 0
+            self.warmup_lrs = np.linspace(self.initial_lr, self.lr_after_warmup, warmup_steps)
 
     def on_epoch_begin(self, epoch: int, logs: dict=None):
         """ Adjust learning rate at the beginning of each epoch """
 
+        if self.warmup_steps:
+            return logs
+
         if epoch >= self.warmup_epochs + self.decay_epochs:
             return logs
 
         if epoch <= self.warmup_epochs:
-            lr = np.linspace(self.initial_lr, self.lr_after_warmup, 5)[epoch-1]
+            lr = self.warmup_lrs[epoch-1]
         else:
             progress = (epoch - self.warmup_epochs) / self.decay_epochs
             lr = self.final_lr + 0.5 * (self.lr_after_warmup - self.final_lr) * (1 + np.cos(np.pi * progress))
 
         self.model.optimizer.param_groups[0]["lr"] = lr
 
         if self.verbose:
-            print(f"Epoch {epoch + 1} - Learning Rate: {lr}")
+            self.logger.info(f"Epoch {epoch} - Learning Rate: {lr}")
+
+    def on_train_batch_begin(self, batch: int, logs: dict=None):
+        if self.warmup_steps and self.step is not None:
+            if self.step < self.warmup_steps:
+                self.model.optimizer.param_groups[0]["lr"] = self.warmup_lrs[self.step]
+                self.step += 1
+            else:
+                self.step = None
 
     def on_epoch_end(self, epoch: int, logs: dict=None):
         logs = logs or {}
```
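
Standalone, the schedule this callback implements is easy to sanity-check: a linear warmup from `initial_lr` to `lr_after_warmup`, then a half-cosine decay down to `final_lr`. A small sketch reproducing the formulas above with this commit's config values (`init_lr=1e-8`, `lr_after_warmup=1e-5`, `final_lr=5e-6`, `warmup_epochs=10`, `decay_epochs=40`):

```python
import numpy as np

initial_lr, lr_after_warmup, final_lr = 1e-8, 1e-5, 5e-6
warmup_epochs, decay_epochs = 10, 40

warmup_lrs = np.linspace(initial_lr, lr_after_warmup, warmup_epochs)

def lr_at(epoch: int) -> float:
    # mirrors on_epoch_begin above (epoch is 1-based, hence epoch-1)
    if epoch <= warmup_epochs:
        return warmup_lrs[epoch - 1]
    progress = (epoch - warmup_epochs) / decay_epochs
    return final_lr + 0.5 * (lr_after_warmup - final_lr) * (1 + np.cos(np.pi * progress))

for epoch in (1, 10, 30, 50):
    print(epoch, f"{lr_at(epoch):.2e}")
# 1  1.00e-08   (start of warmup)
# 10 1.00e-05   (peak after warmup)
# 30 7.50e-06   (halfway down the cosine)
# 50 5.00e-06   (final_lr)
```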
