Skip to content

Commit 5f900f1

Browse files
Edresson and erogol authored
Add XTTS Fine tuning gradio demo (#3296)
* Add XTTS FT demo data processing pipeline * Add training and inference columns * Uses tabs instead of columns * Fix demo freezing issue * Update demo * Convert stereo to mono * Bug fix on XTTS inference * Update gradio demo * Update gradio demo * Update gradio demo * Update gradio demo * Add parameters to be able to set then on colab demo * Add erros messages * Add intuitive error messages * Update * Add max_audio_length parameter * Add XTTS fine-tuner docs * Update XTTS finetuner docs * Delete trainer to freeze memory * Delete unused variables * Add gc.collect() * Update xtts.md --------- Co-authored-by: Eren Gölge <erogol@hotmail.com>
1 parent 6d1905c commit 5f900f1

File tree

7 files changed

+800
-1
lines changed

7 files changed

+800
-1
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
faster_whisper==0.9.0
2+
gradio==4.7.1
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import os
2+
import gc
3+
import torchaudio
4+
import pandas
5+
from faster_whisper import WhisperModel
6+
from glob import glob
7+
8+
from tqdm import tqdm
9+
10+
import torch
11+
import torchaudio
12+
# torch.set_num_threads(1)
13+
14+
from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners
15+
16+
torch.set_num_threads(16)
17+
18+
19+
import os
20+
21+
audio_types = (".wav", ".mp3", ".flac")
22+
23+
24+
def list_audios(basePath, contains=None):
    """Walk *basePath* recursively and yield paths of audio files.

    Only files whose extension is one of ``audio_types``
    (.wav/.mp3/.flac) are yielded; *contains* optionally restricts
    results to file names containing that substring.
    """
    # thin filter over the generic directory walker
    yield from list_files(basePath, validExts=audio_types, contains=contains)
27+
28+
def list_files(basePath, validExts=None, contains=None):
    """Recursively walk *basePath* and yield matching file paths.

    Args:
        basePath: root directory of the walk.
        validExts: optional tuple of lower-case extensions (e.g. (".wav",));
            when None, every file passes the extension check.
        contains: optional substring a file name must contain to be yielded.
    """
    for rootDir, _dirs, names in os.walk(basePath):
        for name in names:
            # skip files that do not contain the requested substring
            if contains is not None and contains not in name:
                continue

            # lower-cased extension taken from the last dot onward
            ext = name[name.rfind("."):].lower()

            # yield the file when no filter is set or the extension matches
            if validExts is None or ext.endswith(validExts):
                yield os.path.join(rootDir, name)
46+
47+
def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):
    """Transcribe audio files and cut them into sentence-level training clips.

    Each input file is transcribed with faster-whisper (word timestamps
    enabled), split at sentence-final punctuation (! . ?), and each sentence's
    audio is saved under ``<out_path>/wavs/``. Two shuffled metadata splits
    (``metadata_train.csv`` / ``metadata_eval.csv``, '|'-separated, Coqui
    formatter columns) are written to *out_path*.

    Args:
        audio_files: iterable of audio file paths to process.
        target_language: language code passed to Whisper and the text cleaner.
        out_path: output directory; created if missing. Must not be None.
        buffer: seconds of padding applied around sentence boundaries.
        eval_percentage: fraction of sentences assigned to the eval split.
        speaker_name: speaker label written into the metadata.
        gradio_progress: optional gradio ``Progress`` object for UI feedback.

    Returns:
        Tuple ``(train_metadata_path, eval_metadata_path, audio_total_size)``
        where ``audio_total_size`` is the summed input duration in seconds.
    """
    audio_total_size = 0
    # make sure that the output directory exists
    os.makedirs(out_path, exist_ok=True)

    # Loading Whisper
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("Loading Whisper Model!")
    # NOTE(review): compute_type="float16" presumes a CUDA device; behavior on
    # a CPU-only machine depends on faster-whisper's fallback — confirm.
    asr_model = WhisperModel("large-v2", device=device, compute_type="float16")

    metadata = {"audio_file": [], "text": [], "speaker_name": []}

    # drive progress either through the gradio UI or a plain tqdm bar
    if gradio_progress is not None:
        tqdm_object = gradio_progress.tqdm(audio_files, desc="Formatting...")
    else:
        tqdm_object = tqdm(audio_files)

    for audio_path in tqdm_object:
        wav, sr = torchaudio.load(audio_path)
        # stereo to mono if needed
        if wav.size(0) != 1:
            wav = torch.mean(wav, dim=0, keepdim=True)

        wav = wav.squeeze()
        audio_total_size += (wav.size(-1) / sr)

        segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
        segments = list(segments)
        i = 0  # sentence index used in the output clip file names
        sentence = ""
        sentence_start = None
        first_word = True
        # added all segments words in a unique list
        words_list = []
        for _, segment in enumerate(segments):
            words = list(segment.words)
            words_list.extend(words)

        # process each word: accumulate words into a sentence, and on
        # sentence-final punctuation cut + save the corresponding audio span
        for word_idx, word in enumerate(words_list):
            if first_word:
                sentence_start = word.start
                # If it is the first sentence, add buffer or get the begining of the file
                if word_idx == 0:
                    sentence_start = max(sentence_start - buffer, 0)  # Add buffer to the sentence start
                else:
                    # get previous sentence end
                    previous_word_end = words_list[word_idx - 1].end
                    # add buffer or get the silence midle between the previous sentence and the current one
                    sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2)

                sentence = word.word
                first_word = False
            else:
                sentence += word.word

            if word.word[-1] in ["!", ".", "?"]:
                # drop the leading character — presumably the space Whisper
                # prepends to each word token; TODO confirm for all languages
                sentence = sentence[1:]
                # Expand number and abbreviations plus normalization
                sentence = multilingual_cleaners(sentence, target_language)
                audio_file_name, _ = os.path.splitext(os.path.basename(audio_path))

                audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav"

                # Check for the next word's existence
                if word_idx + 1 < len(words_list):
                    next_word_start = words_list[word_idx + 1].start
                else:
                    # If don't have more words it means that it is the last sentence then use the audio len as next word start
                    next_word_start = (wav.shape[0] - 1) / sr

                # Average the current word end and next word start
                word_end = min((word.end + next_word_start) / 2, word.end + buffer)

                absoulte_path = os.path.join(out_path, audio_file)
                os.makedirs(os.path.dirname(absoulte_path), exist_ok=True)
                i += 1
                first_word = True

                audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0)
                # if the audio is too short ignore it (i.e < 0.33 seconds)
                if audio.size(-1) >= sr/3:
                    torchaudio.save(absoulte_path,
                        audio,
                        sr
                    )
                else:
                    # skip metadata for the too-short clip; note the clip
                    # index `i` was already advanced, so the name is consumed
                    continue

                metadata["audio_file"].append(audio_file)
                metadata["text"].append(sentence)
                metadata["speaker_name"].append(speaker_name)

    # shuffle, then split into eval/train portions
    df = pandas.DataFrame(metadata)
    df = df.sample(frac=1)
    num_val_samples = int(len(df)*eval_percentage)

    df_eval = df[:num_val_samples]
    df_train = df[num_val_samples:]

    df_train = df_train.sort_values('audio_file')
    train_metadata_path = os.path.join(out_path, "metadata_train.csv")
    df_train.to_csv(train_metadata_path, sep="|", index=False)

    eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
    df_eval = df_eval.sort_values('audio_file')
    df_eval.to_csv(eval_metadata_path, sep="|", index=False)

    # deallocate VRAM and RAM
    del asr_model, df_train, df_eval, df, metadata
    gc.collect()

    return train_metadata_path, eval_metadata_path, audio_total_size
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import os
2+
import gc
3+
4+
from trainer import Trainer, TrainerArgs
5+
6+
from TTS.config.shared_configs import BaseDatasetConfig
7+
from TTS.tts.datasets import load_tts_samples
8+
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
9+
from TTS.utils.manage import ModelManager
10+
11+
12+
def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path, max_audio_length=255995):
    """Fine-tune the XTTS v2 GPT model on a prepared Coqui-format dataset.

    Downloads the XTTS v2 base checkpoint plus DVAE files if they are not
    already cached under ``<output_path>/run/training/``, builds the trainer
    configuration, runs training, and frees the model/trainer afterwards.

    Args:
        language: dataset language code.
        num_epochs: number of training epochs.
        batch_size: train and eval batch size.
        grad_acumm: gradient accumulation steps.
        train_csv: path to metadata_train.csv (its directory is the dataset root).
        eval_csv: path to metadata_eval.csv.
        output_path: root directory for checkpoints and logs.
        max_audio_length: max training wav length in samples
            (default 255995 — ~11.6 s at 22050 Hz).

    Returns:
        Tuple ``(XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE,
        trainer_out_path, speaker_ref)`` where *speaker_ref* is the audio
        file of the training sample with the longest text (used later as the
        speaker reference for inference).
    """
    # Logging parameters
    RUN_NAME = "GPT_XTTS_FT"
    PROJECT_NAME = "XTTS_trainer"
    DASHBOARD_LOGGER = "tensorboard"
    LOGGER_URI = None

    # Set here the path that the checkpoints will be saved. Default: ./run/training/
    OUT_PATH = os.path.join(output_path, "run", "training")

    # Training Parameters
    OPTIMIZER_WD_ONLY_ON_WEIGHTS = True  # for multi-gpu training please make it False
    START_WITH_EVAL = False  # if True it will star with evaluation
    BATCH_SIZE = batch_size  # set here the batch size
    GRAD_ACUMM_STEPS = grad_acumm  # set here the grad accumulation steps

    # Define here the dataset that you want to use for the fine-tuning on.
    config_dataset = BaseDatasetConfig(
        formatter="coqui",
        dataset_name="ft_dataset",
        path=os.path.dirname(train_csv),
        meta_file_train=train_csv,
        meta_file_val=eval_csv,
        language=language,
    )

    # Add here the configs of the datasets
    DATASETS_CONFIG_LIST = [config_dataset]

    # Define the path where XTTS v2.0.1 files will be downloaded
    CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
    os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)

    # DVAE files
    DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
    MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"

    # Set the path to the downloaded files
    DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK))
    MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK))

    # download DVAE files if needed
    if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
        print(" > Downloading DVAE files!")
        ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)

    # Download XTTS v2.0 checkpoint if needed
    TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
    XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"
    XTTS_CONFIG_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json"

    # XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning.
    TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK))  # vocab.json file
    XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK))  # model.pth file
    XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CONFIG_LINK))  # config.json file

    # download XTTS v2.0 files if needed
    if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT):
        print(" > Downloading XTTS v2.0 files!")
        ModelManager._download_model_files(
            [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK, XTTS_CONFIG_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
        )

    # init args and config
    model_args = GPTArgs(
        max_conditioning_length=132300,  # 6 secs
        min_conditioning_length=66150,  # 3 secs
        debug_loading_failures=False,
        max_wav_length=max_audio_length,  # ~11.6 seconds
        max_text_length=200,
        mel_norm_file=MEL_NORM_FILE,
        dvae_checkpoint=DVAE_CHECKPOINT,
        xtts_checkpoint=XTTS_CHECKPOINT,  # checkpoint path of the model that you want to fine-tune
        tokenizer_file=TOKENIZER_FILE,
        gpt_num_audio_tokens=1026,
        gpt_start_audio_token=1024,
        gpt_stop_audio_token=1025,
        gpt_use_masking_gt_prompt_approach=True,
        gpt_use_perceiver_resampler=True,
    )
    # define audio config
    audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
    # training parameters config
    config = GPTTrainerConfig(
        epochs=num_epochs,
        output_path=OUT_PATH,
        model_args=model_args,
        run_name=RUN_NAME,
        project_name=PROJECT_NAME,
        run_description="""
            GPT XTTS training
            """,
        dashboard_logger=DASHBOARD_LOGGER,
        logger_uri=LOGGER_URI,
        audio=audio_config,
        batch_size=BATCH_SIZE,
        batch_group_size=48,
        eval_batch_size=BATCH_SIZE,
        num_loader_workers=8,
        eval_split_max_size=256,
        print_step=50,
        plot_step=100,
        log_model_step=100,
        save_step=1000,
        save_n_checkpoints=1,
        save_checkpoints=True,
        # target_loss="loss",
        print_eval=False,
        # Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
        optimizer="AdamW",
        optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
        optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
        lr=5e-06,  # learning rate
        lr_scheduler="MultiStepLR",
        # it was adjusted accordly for the new step scheme
        lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
        test_sentences=[],
    )

    # init the model from config
    model = GPTTrainer.init_from_config(config)

    # load training samples
    train_samples, eval_samples = load_tts_samples(
        DATASETS_CONFIG_LIST,
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )

    # init the trainer and 🚀
    trainer = Trainer(
        TrainerArgs(
            restore_path=None,  # xtts checkpoint is restored via xtts_checkpoint key so no need of restore it using Trainer restore_path parameter
            skip_train_epoch=False,
            start_with_eval=START_WITH_EVAL,
            grad_accum_steps=GRAD_ACUMM_STEPS,
        ),
        config,
        output_path=OUT_PATH,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )
    trainer.fit()

    # get the longest text audio file to use as speaker reference
    samples_len = [len(item["text"].split(" ")) for item in train_samples]
    longest_text_idx = samples_len.index(max(samples_len))
    speaker_ref = train_samples[longest_text_idx]["audio_file"]

    trainer_out_path = trainer.output_path

    # deallocate VRAM and RAM
    del model, trainer, train_samples, eval_samples
    gc.collect()

    return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref

0 commit comments

Comments
 (0)