1+ import tensorflow as tf
# Enable GPU memory growth so TensorFlow allocates VRAM on demand instead of
# reserving all of it up front. Best effort: failures are non-fatal.
try:
    for gpu in tf.config.experimental.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)
except Exception:
    # No GPU present, or the runtime was already initialized -- training can
    # still proceed with the default device configuration.
    pass

# Mixed precision (float16 compute, float32 variables) for faster training on
# GPUs with tensor cores.
tf.keras.mixed_precision.set_global_policy("mixed_float16")
5+
6+ import os
7+ import tarfile
8+ import pandas as pd
9+ from tqdm import tqdm
10+ from urllib .request import urlopen
11+ from io import BytesIO
12+
13+ from keras .callbacks import EarlyStopping , ModelCheckpoint , ReduceLROnPlateau , TensorBoard
14+ from mltu .preprocessors import WavReader
15+
16+ from mltu .tensorflow .dataProvider import DataProvider
17+ from mltu .transformers import LabelIndexer , LabelPadding , SpectrogramPadding
18+ from mltu .tensorflow .losses import CTCloss
19+ from mltu .tensorflow .callbacks import Model2onnx , TrainLogger
20+ from mltu .tensorflow .metrics import CERMetric , WERMetric
21+
22+ from model import train_model
23+ from configs import ModelConfigs
24+
25+
def download_and_unzip(url, extract_to="Datasets", chunk_size=1024 * 1024):
    """Download a .tar.bz2 archive from ``url`` and extract it into ``extract_to``.

    The archive is read in ``chunk_size``-byte chunks with a tqdm progress bar,
    buffered in memory, then extracted.
    """
    http_response = urlopen(url)

    # ``length`` may be None (e.g. chunked transfer encoding), so the total is
    # only a tqdm hint; the loop itself runs until the stream is exhausted.
    total = (http_response.length // chunk_size + 1) if http_response.length else None

    # Collect chunks in a list and join once -- repeated ``bytes +=`` is
    # quadratic for multi-gigabyte downloads.
    chunks = []
    with tqdm(total=total) as progress:
        while chunk := http_response.read(chunk_size):
            chunks.append(chunk)
            progress.update(1)

    # "r|bz2" streams the bz2 tar; the context manager guarantees closure.
    with tarfile.open(fileobj=BytesIO(b"".join(chunks)), mode="r|bz2") as tar_file:
        tar_file.extractall(path=extract_to)
37+
38+
# Download the LJSpeech dataset only if it is not already on disk.
dataset_path = os.path.join("Datasets", "LJSpeech-1.1")
if not os.path.exists(dataset_path):
    download_and_unzip("https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2", extract_to="Datasets")

metadata_path = os.path.join(dataset_path, "metadata.csv")
wavs_path = os.path.join(dataset_path, "wavs")

# Read the metadata file: pipe-separated, no header row, quoting disabled
# (quoting=3 == csv.QUOTE_NONE) because transcriptions may contain quotes.
metadata_df = pd.read_csv(metadata_path, sep="|", header=None, quoting=3)
metadata_df.columns = ["file_name", "transcription", "normalized_transcription"]
metadata_df = metadata_df[["file_name", "normalized_transcription"]]

# Structure the dataset: each row is [wav_file_path, lower-cased transcription].
dataset = [
    [os.path.join(wavs_path, f"{file_name}.wav"), label.lower()]
    for file_name, label in metadata_df.values.tolist()
]

# Create a ModelConfigs object to store model configurations and persist it.
configs = ModelConfigs()
configs.save()
58+
# Data pipeline: decode wav files into spectrograms, map transcriptions to
# integer label sequences, then pad both to uniform sizes per batch.
spectrogram_reader = WavReader(
    frame_length=configs.frame_length,
    frame_step=configs.frame_step,
    fft_length=configs.fft_length,
)

data_provider = DataProvider(
    dataset=dataset,
    skip_validation=True,
    batch_size=configs.batch_size,
    data_preprocessors=[spectrogram_reader],
    transformers=[LabelIndexer(configs.vocab)],
    batch_postprocessors=[
        # Spectrograms are padded with zeros; labels are padded with
        # len(vocab), an index outside the valid vocabulary range.
        SpectrogramPadding(padding_value=0, use_on_batch=True),
        LabelPadding(padding_value=len(configs.vocab), use_on_batch=True),
    ],
)

# Hold out 10% of the samples for validation.
train_data_provider, val_data_provider = data_provider.split(split=0.9)
78+
# Build the model architecture. The per-frame feature dimension of a one-sided
# spectrogram is fft_length // 2 + 1; the previous hard-coded 193 assumed
# fft_length == 384, so derive it from the configs instead.
model = train_model(
    input_dim=(None, configs.fft_length // 2 + 1),
    output_dim=len(configs.vocab),
    dropout=0.5,
)

# Compile with CTC loss; character and word error rates are tracked as
# metrics so callbacks can monitor val_CER.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=configs.learning_rate),
    loss=CTCloss(),
    metrics=[
        CERMetric(vocabulary=configs.vocab),
        WERMetric(vocabulary=configs.vocab),
    ],
    run_eagerly=False,
)
model.summary(line_length=110)
97+
98+ # Define callbacks
99+ earlystopper = EarlyStopping (monitor = "val_CER" , patience = 20 , verbose = 1 , mode = "min" )
100+ checkpoint = ModelCheckpoint (f"{ configs .model_path } /model.h5" , monitor = "val_CER" , verbose = 1 , save_best_only = True , mode = "min" )
101+ trainLogger = TrainLogger (configs .model_path )
102+ tb_callback = TensorBoard (f"{ configs .model_path } /logs" , update_freq = 1 )
103+ reduceLROnPlat = ReduceLROnPlateau (monitor = "val_CER" , factor = 0.8 , min_delta = 1e-10 , patience = 5 , verbose = 1 , mode = "auto" )
104+ model2onnx = Model2onnx (f"{ configs .model_path } /model.h5" )
105+
106+ # Train the model
# Train the model; `workers` parallelizes data loading.
training_callbacks = [
    earlystopper,
    checkpoint,
    trainLogger,
    reduceLROnPlat,
    tb_callback,
    model2onnx,
]
model.fit(
    train_data_provider,
    validation_data=val_data_provider,
    epochs=configs.train_epochs,
    callbacks=training_callbacks,
    workers=configs.train_workers,
)

# Persist the exact train/validation split alongside the model artifacts.
for provider, csv_name in ((train_data_provider, "train.csv"), (val_data_provider, "val.csv")):
    provider.to_csv(os.path.join(configs.model_path, csv_name))