Commit 30d5280

Merge branch 'feature/speechTransformer' into develop
2 parents dd93806 + dab3a49

23 files changed: +1354 −176 lines changed

.vscode/launch.json

2 additions, 1 deletion

@@ -10,7 +10,8 @@
             "request": "launch",
             "program": "${file}",
             "console": "integratedTerminal",
-            "justMyCode": false
+            "justMyCode": false,
+            "subProcess": true,
         }
     ]
 }
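The new `"subProcess": true` flag is a debugpy option that makes the VS Code Python debugger attach to spawned child processes as well; presumably it was added here so the multiprocessing workers introduced by this commit's `mltu.torch.dataProvider` changes can be debugged too.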

CHANGELOG.md

11 additions, 0 deletions

@@ -1,3 +1,13 @@
+## [1.1.1] - 2023-09-26
+### Changed
+- Included `self._executor` as a generator in the `mltu.dataProvider.DataProvider` object, to enable modifying batch preprocessing without changing the original code
+- Introduced changes in `mltu.torch.dataProvider.py` to handle data in multiprocessing and multithreading modes, for faster preprocessing while training torch models
+- Modified the `mltu.transformers.AudioPadding` object to work with batches of raw audio data
+
+### Added
+- Created tutorial `10_wav2vec2_torch` (Audio to Text model) that shows how to train a wav2vec2 model with mltu
+
+
 ## [1.1.0] - 2023-08-28
 ### Changed
 - Changed `mltu.transformers.SpectrogramPadding` object, to pad spectrogram end with zeros instead of start

@@ -10,6 +20,7 @@
 - Created `mltu.tensorflow.transformer.callbacks` module, that contains `EncDecSplitCallback` callback, to split Transformer model into separate encoder and decoder models
 - Created `mltu.tensorflow.transformer.utils` module, that contains `MaskedLoss` loss and `MaskedAccuracy` metric, used for training Transformer models
 
+
 ## [1.0.15] - 2023-07-15
 ### Changed
 - Fixed bug in `mltu.dataProvider.DataProvider` to work with `batch_postprocessors`.
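The `batch_postprocessors` hook referenced in the 1.0.15 entry (and used in the wav2vec2 tutorial below) operates on a fully assembled batch rather than on individual samples. A minimal sketch of the idea, assuming the callable simply receives and returns the `(batch_data, batch_labels)` pair — the exact mltu signature may differ, and `normalize_audio_batch` is a hypothetical example:

```python
import numpy as np

def normalize_audio_batch(batch_audio: np.ndarray, batch_labels: np.ndarray):
    # Hypothetical batch postprocessor: peak-normalize each clip in a
    # padded (batch, samples) float array, e.g. after AudioPadding has run.
    peaks = np.max(np.abs(batch_audio), axis=1, keepdims=True) + 1e-9
    return batch_audio / peaks, batch_labels

# Registered the same way AudioPadding is registered in train.py below:
# DataProvider(..., batch_postprocessors=[AudioPadding(...), normalize_audio_batch])
```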

Tutorials/09_translation_transformer/train.py

1 addition, 0 deletions

@@ -131,6 +131,7 @@ def preprocess_inputs(data_batch, label_batch):
     validation_data=val_dataProvider,
     epochs=configs.train_epochs,
     callbacks=[
+        earlystopper,
         warmupCosineDecay,
         checkpoint,
         tb_callback,
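This one-line change registers the `earlystopper` callback (presumably constructed earlier in the script, outside this hunk) with the translation transformer's `fit` call, so training stops once the monitored validation metric stops improving instead of always running the full epoch count.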
Tutorials/10_wav2vec2_torch/configs.py

28 additions, 0 deletions (new file)

@@ -0,0 +1,28 @@
+import os
+from datetime import datetime
+
+from mltu.configs import BaseModelConfigs
+
+class ModelConfigs(BaseModelConfigs):
+    def __init__(self):
+        super().__init__()
+        self.model_path = os.path.join(
+            "Models/10_wav2vec2_torch",
+            datetime.strftime(datetime.now(), "%Y%m%d%H%M"),
+        )
+        self.batch_size = 8
+        self.train_epochs = 60
+        self.train_workers = 20
+
+        self.init_lr = 1.0e-8
+        self.lr_after_warmup = 1e-05
+        self.final_lr = 5e-06
+        self.warmup_epochs = 10
+        self.decay_epochs = 40
+        self.weight_decay = 0.005
+        self.mixed_precision = True
+
+        self.max_audio_length = 246000
+        self.max_label_length = 256
+
+        self.vocab = [' ', "'", 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
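train.py below calls `configs.save()` after finalizing the vocabulary. Assuming mltu's usual `BaseModelConfigs` behavior of serializing the attributes to a YAML file under `model_path` (the exact file name and `load` usage are assumptions here), the round trip looks roughly like:

```python
from mltu.configs import BaseModelConfigs
from configs import ModelConfigs

configs = ModelConfigs()
configs.save()  # assumed to write <model_path>/configs.yaml

# Later, e.g. at inference time (the timestamped path is illustrative):
loaded = BaseModelConfigs.load("Models/10_wav2vec2_torch/202309171434/configs.yaml")
print(loaded.vocab)
```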
Tutorials/10_wav2vec2_torch/requirements.txt

3 additions, 0 deletions (new file)

@@ -0,0 +1,3 @@
+torch==1.13.1+cu117
+transformers==4.33.1
+onnx
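Note that the `+cu117` build of torch is not published on PyPI; it is normally installed from the PyTorch wheel index, e.g. `pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117`.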
Tutorials/10_wav2vec2_torch/test.py

43 additions, 0 deletions (new file)

@@ -0,0 +1,43 @@
+import numpy as np
+
+from mltu.inferenceModel import OnnxInferenceModel
+from mltu.utils.text_utils import ctc_decoder, get_cer, get_wer
+
+class Wav2vec2(OnnxInferenceModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def predict(self, audio: np.ndarray):
+        # Add a batch dimension and cast to the float32 input the ONNX model expects
+        audio = np.expand_dims(audio, axis=0).astype(np.float32)
+
+        preds = self.model.run(None, {self.input_name: audio})[0]
+
+        text = ctc_decoder(preds, self.metadata["vocab"])[0]
+
+        return text
+
+if __name__ == "__main__":
+    import librosa
+    import pandas as pd
+    from tqdm import tqdm
+
+    model = Wav2vec2(model_path="Models/10_wav2vec2_torch/202309171434/model.onnx")
+
+    # The list of [audio_path, label] pairs for validation
+    val_dataset = pd.read_csv("Models/10_wav2vec2_torch/202309171434/val.csv").values.tolist()
+
+    accum_cer, accum_wer = [], []
+    pbar = tqdm(val_dataset)
+    for wav_path, label in pbar:
+        audio, sr = librosa.load(wav_path, sr=16000)
+
+        prediction_text = model.predict(audio)
+
+        cer = get_cer(prediction_text, label)
+        wer = get_wer(prediction_text, label)
+
+        accum_cer.append(cer)
+        accum_wer.append(wer)
+
+        pbar.set_description(f"Average CER: {np.average(accum_cer):.4f}, Average WER: {np.average(accum_wer):.4f}")
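For context, `ctc_decoder` turns the model's per-frame `(batch, time, vocab)` predictions into text. A minimal greedy version of the same idea — a sketch of the standard algorithm, not mltu's exact implementation, with the blank token assumed to be the last index to match `CTCLoss(blank=len(vocab))` in train.py below:

```python
import numpy as np

def greedy_ctc_decode(preds: np.ndarray, vocab: list) -> list:
    """Greedy CTC decoding: argmax per frame, collapse repeats, drop blanks."""
    blank = len(vocab)  # assumed blank index, matching CTCLoss(blank=len(vocab))
    results = []
    for sequence in np.argmax(preds, axis=-1):  # best token per time frame
        decoded, previous = [], None
        for token in sequence:
            if token != previous and token != blank:
                decoded.append(vocab[token])
            previous = token
        results.append("".join(decoded))
    return results
```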
Tutorials/10_wav2vec2_torch/train.py

157 additions, 0 deletions (new file)

@@ -0,0 +1,157 @@
+import os
+import tarfile
+import pandas as pd
+from tqdm import tqdm
+from io import BytesIO
+from urllib.request import urlopen
+
+import torch
+from torch import nn
+from transformers import Wav2Vec2ForCTC
+import torch.nn.functional as F
+
+from mltu.torch.model import Model
+from mltu.torch.losses import CTCLoss
+from mltu.torch.dataProvider import DataProvider
+from mltu.torch.metrics import CERMetric, WERMetric
+from mltu.torch.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, Model2onnx, WarmupCosineDecay
+from mltu.augmentors import RandomAudioNoise, RandomAudioPitchShift, RandomAudioTimeStretch
+
+from mltu.preprocessors import AudioReader
+from mltu.transformers import LabelIndexer, LabelPadding, AudioPadding
+
+from configs import ModelConfigs
+
+configs = ModelConfigs()
+
+
+def download_and_unzip(url, extract_to="Datasets", chunk_size=1024*1024):
+    http_response = urlopen(url)
+
+    data = b""
+    iterations = http_response.length // chunk_size + 1
+    for _ in tqdm(range(iterations)):
+        data += http_response.read(chunk_size)
+
+    tarFile = tarfile.open(fileobj=BytesIO(data), mode="r|bz2")
+    tarFile.extractall(path=extract_to)
+    tarFile.close()
+
+
+dataset_path = os.path.join("Datasets", "LJSpeech-1.1")
+if not os.path.exists(dataset_path):
+    download_and_unzip("https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2", extract_to="Datasets")
+
+dataset_path = "Datasets/LJSpeech-1.1"
+metadata_path = dataset_path + "/metadata.csv"
+wavs_path = dataset_path + "/wavs/"
+
+# Read metadata file and parse it
+metadata_df = pd.read_csv(metadata_path, sep="|", header=None, quoting=3)
+dataset = []
+vocab = [' ', "'", 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
+for file_name, transcription, normalized_transcription in metadata_df.values.tolist():
+    path = f"Datasets/LJSpeech-1.1/wavs/{file_name}.wav"
+    new_label = "".join([l for l in normalized_transcription.lower() if l in vocab])
+    dataset.append([path, new_label])
+
+# Create a data provider for the dataset
+data_provider = DataProvider(
+    dataset=dataset,
+    skip_validation=True,
+    batch_size=configs.batch_size,
+    data_preprocessors=[
+        AudioReader(sample_rate=16000),
+    ],
+    transformers=[
+        LabelIndexer(vocab),
+        LabelPadding(max_word_length=configs.max_label_length, padding_value=len(vocab)),
+    ],
+    use_cache=False,
+    batch_postprocessors=[
+        AudioPadding(max_audio_length=configs.max_audio_length, padding_value=0, use_on_batch=True)
+    ],
+    use_multiprocessing=True,
+    max_queue_size=10,
+    workers=64,
+)
+train_dataProvider, test_dataProvider = data_provider.split(split=0.9)
+
+# train_dataProvider.augmentors = [
+#     RandomAudioNoise(),
+#     RandomAudioPitchShift(),
+#     RandomAudioTimeStretch()
+# ]
+
+vocab = sorted(vocab)
+configs.vocab = vocab
+configs.save()
+
+
+class CustomWav2Vec2Model(nn.Module):
+    def __init__(self, hidden_states, dropout_rate=0.2, **kwargs):
+        super(CustomWav2Vec2Model, self).__init__(**kwargs)
+        pretrained_name = "facebook/wav2vec2-base-960h"
+        self.model = Wav2Vec2ForCTC.from_pretrained(pretrained_name, vocab_size=hidden_states, ignore_mismatched_sizes=True)
+        self.model.freeze_feature_encoder()  # the convolutional feature encoder does not need to be fine-tuned
+
+    def forward(self, inputs):
+        output = self.model(inputs, attention_mask=None).logits
+        # Apply log softmax over the vocabulary dimension, as expected by CTCLoss
+        output = F.log_softmax(output, -1)
+        return output
+
+custom_model = CustomWav2Vec2Model(hidden_states=len(vocab)+1)
+
+# put on cuda device if available
+if torch.cuda.is_available():
+    custom_model = custom_model.cuda()
+
+# create callbacks
+warmupCosineDecay = WarmupCosineDecay(
+    lr_after_warmup=configs.lr_after_warmup,
+    warmup_epochs=configs.warmup_epochs,
+    decay_epochs=configs.decay_epochs,
+    final_lr=configs.final_lr,
+    initial_lr=configs.init_lr,
+    verbose=True,
+)
+tb_callback = TensorBoard(configs.model_path + "/logs")
+earlyStopping = EarlyStopping(monitor="val_CER", patience=16, mode="min", verbose=1)
+modelCheckpoint = ModelCheckpoint(configs.model_path + "/model.pt", monitor="val_CER", mode="min", save_best_only=True, verbose=1)
+model2onnx = Model2onnx(
+    saved_model_path=configs.model_path + "/model.pt",
+    input_shape=(1, configs.max_audio_length),
+    verbose=1,
+    metadata={"vocab": configs.vocab},
+    dynamic_axes={"input": {0: "batch_size", 1: "sequence_length"}, "output": {0: "batch_size", 1: "sequence_length"}}
+)
+
+# create model object that will handle training and testing of the network
+model = Model(
+    custom_model,
+    loss=CTCLoss(blank=len(configs.vocab), zero_infinity=True),
+    optimizer=torch.optim.AdamW(custom_model.parameters(), lr=configs.init_lr, weight_decay=configs.weight_decay),
+    metrics=[
+        CERMetric(configs.vocab),
+        WERMetric(configs.vocab)
+    ],
+    mixed_precision=configs.mixed_precision,
+)
+
+# Save training and validation datasets as csv files
+train_dataProvider.to_csv(os.path.join(configs.model_path, "train.csv"))
+test_dataProvider.to_csv(os.path.join(configs.model_path, "val.csv"))
+
+model.fit(
+    train_dataProvider,
+    test_dataProvider,
+    epochs=configs.train_epochs,
+    callbacks=[
+        warmupCosineDecay,
+        tb_callback,
+        earlyStopping,
+        modelCheckpoint,
+        model2onnx
+    ]
+)
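The `Model2onnx` callback above presumably wraps PyTorch's standard exporter; its `dynamic_axes` mapping is what lets the exported model accept variable batch sizes and audio lengths at inference time, which test.py relies on. A minimal sketch of the equivalent direct export, assuming `custom_model` is the trained model from above and the output path is illustrative:

```python
import torch

# Dummy input matching the callback's input_shape: (batch, samples)
dummy_input = torch.randn(1, 246000)  # 246000 = configs.max_audio_length
torch.onnx.export(
    custom_model.cpu().eval(),
    dummy_input,
    "Models/10_wav2vec2_torch/model.onnx",  # illustrative path
    input_names=["input"],
    output_names=["output"],
    # Same dynamic axes as the Model2onnx callback: both the batch size and
    # the audio/sequence length may vary at inference time.
    dynamic_axes={
        "input": {0: "batch_size", 1: "sequence_length"},
        "output": {0: "batch_size", 1: "sequence_length"},
    },
)
```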
