Skip to content

Commit 8f5ef3a

Browse files
authored
Update training guide (#102)
* Update README.md * Update README.md * Update README.md * update configs and readme * fix training and eval single gpus and long audios errors * fix error transcriptions none * fix transcription null wer --------- Co-authored-by: [email protected] <Yoach Lacombe>
1 parent 9f34c1b commit 8f5ef3a

File tree

10 files changed

+285
-59
lines changed

10 files changed

+285
-59
lines changed

README.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,6 @@ We've set up an [inference guide](INFERENCE.md) to make generation faster. Think
118118
https://github.com/huggingface/parler-tts/assets/52246514/251e2488-fe6e-42c1-81cd-814c5b7795b0
119119

120120
## Training
121-
> [!WARNING]
122-
> The training guide has yet to be adapted to the newest checkpoints.
123121

124122
<a target="_blank" href="https://colab.research.google.com/github/ylacombe/scripts_and_notebooks/blob/main/Finetuning_Parler_TTS_on_a_single_speaker_dataset.ipynb">
125123
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
@@ -131,12 +129,15 @@ The [training folder](/training/) contains all the information to train or fine-
131129
- [3. A training guide](/training/README.md#3-training)
132130

133131
> [!IMPORTANT]
134-
> **TL;DR:** After having followed the [installation steps](/training/README.md#requirements), you can reproduce the Parler-TTS Mini v0.1 training recipe with the following command line:
132+
> **TL;DR:** After having followed the [installation steps](/training/README.md#requirements), you can reproduce the Parler-TTS Mini v1 training recipe with the following command line:
135133
136134
```sh
137-
accelerate launch ./training/run_parler_tts_training.py ./helpers/training_configs/starting_point_0.01.json
135+
accelerate launch ./training/run_parler_tts_training.py ./helpers/training_configs/starting_point_v1.json
138136
```
139137

138+
> [!IMPORTANT]
139+
> You can also follow [this fine-tuning guide](https://colab.research.google.com/github/ylacombe/scripts_and_notebooks/blob/main/Finetuning_Parler_TTS_on_a_single_speaker_dataset.ipynb) on a mono-speaker dataset example.
140+
140141
## Acknowledgements
141142

142143
This library builds on top of a number of open-source giants, to whom we'd like to extend our warmest thanks for providing these tools!

helpers/model_init_scripts/init_dummy_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
# set other default generation config params
6262
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
6363
model.generation_config.do_sample = True # True
64-
model.generation_config.guidance_scale = 1 # 3.0
64+
6565

6666
model.config.pad_token_id = encodec_vocab_size
6767
model.config.decoder_start_token_id = encodec_vocab_size + 1

helpers/model_init_scripts/init_dummy_model_with_encodec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
# set other default generation config params
6060
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
6161
model.generation_config.do_sample = True # True
62-
model.generation_config.guidance_scale = 1 # 3.0
62+
6363

6464
model.config.pad_token_id = encodec_vocab_size
6565
model.config.decoder_start_token_id = encodec_vocab_size + 1
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from parler_tts import ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration, ParlerTTSDecoderConfig
2+
from transformers import AutoConfig
3+
import os
4+
import argparse
5+
6+
7+
if __name__ == "__main__":
8+
parser = argparse.ArgumentParser()
9+
parser.add_argument("save_directory", type=str, help="Directory where to save the model and the decoder.")
10+
parser.add_argument("--text_model", type=str, help="Repository id or path to the text encoder.")
11+
parser.add_argument("--audio_model", type=str, help="Repository id or path to the audio encoder.")
12+
13+
args = parser.parse_args()
14+
15+
text_model = args.text_model
16+
encodec_version = args.audio_model
17+
18+
t5 = AutoConfig.from_pretrained(text_model)
19+
encodec = AutoConfig.from_pretrained(encodec_version)
20+
21+
encodec_vocab_size = encodec.codebook_size
22+
num_codebooks = encodec.num_codebooks
23+
print("num_codebooks", num_codebooks)
24+
25+
decoder_config = ParlerTTSDecoderConfig(
26+
vocab_size=encodec_vocab_size + 64, # + 64 instead of +1 to have a multiple of 64
27+
max_position_embeddings=4096, # 30 s = 2580
28+
num_hidden_layers=30,
29+
ffn_dim=6144,
30+
num_attention_heads=24,
31+
num_key_value_heads=24,
32+
layerdrop=0.0,
33+
use_cache=True,
34+
activation_function="gelu",
35+
hidden_size=1536,
36+
dropout=0.1,
37+
attention_dropout=0.0,
38+
activation_dropout=0.0,
39+
pad_token_id=encodec_vocab_size,
40+
eos_token_id=encodec_vocab_size,
41+
bos_token_id=encodec_vocab_size + 1,
42+
num_codebooks=num_codebooks,
43+
)
44+
45+
decoder = ParlerTTSForCausalLM(decoder_config)
46+
decoder.save_pretrained(os.path.join(args.save_directory, "decoder"))
47+
48+
model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
49+
text_encoder_pretrained_model_name_or_path=text_model,
50+
audio_encoder_pretrained_model_name_or_path=encodec_version,
51+
decoder_pretrained_model_name_or_path=os.path.join(args.save_directory, "decoder"),
52+
vocab_size=t5.vocab_size,
53+
)
54+
55+
# set the appropriate bos/pad token ids
56+
model.generation_config.decoder_start_token_id = encodec_vocab_size + 1
57+
model.generation_config.pad_token_id = encodec_vocab_size
58+
model.generation_config.eos_token_id = encodec_vocab_size
59+
60+
# set other default generation config params
61+
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
62+
model.generation_config.do_sample = True # True
63+
64+
65+
model.config.pad_token_id = encodec_vocab_size
66+
model.config.decoder_start_token_id = encodec_vocab_size + 1
67+
68+
model.save_pretrained(os.path.join(args.save_directory, "parler-tts-untrained-larger/"))

helpers/model_init_scripts/init_model_600M.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
# set other default generation config params
6262
model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
6363
model.generation_config.do_sample = True # True
64-
model.generation_config.guidance_scale = 1 # 3.0
6564

6665
model.config.pad_token_id = encodec_vocab_size
6766
model.config.decoder_start_token_id = encodec_vocab_size + 1
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
{
2+
"model_name_or_path": "./parler-tts-untrained-600M/parler-tts-untrained-600M/",
3+
"save_to_disk": "./tmp_dataset_audio/",
4+
"temporary_save_to_disk": "./audio_code_tmp/",
5+
"wandb_project": "parler-tts-50k-hours",
6+
"wandb_run_name": "Mini",
7+
8+
"feature_extractor_name":"ylacombe/dac_44khZ_8kbps",
9+
"description_tokenizer_name":"google/flan-t5-large",
10+
"prompt_tokenizer_name":"google/flan-t5-large",
11+
12+
"report_to": ["wandb"],
13+
"overwrite_output_dir": true,
14+
"output_dir": "./output_dir_training",
15+
16+
"train_dataset_name": "ylacombe/libritts_r_filtered+ylacombe/libritts_r_filtered+ylacombe/libritts_r_filtered+parler-tts/mls_eng",
17+
"train_metadata_dataset_name": "ylacombe/libritts-r-filtered-descriptions-10k-v5-without-accents+ylacombe/libritts-r-filtered-descriptions-10k-v5-without-accents+ylacombe/libritts-r-filtered-descriptions-10k-v5-without-accents+ylacombe/mls-eng-descriptions-v4",
18+
"train_dataset_config_name": "clean+clean+other+default",
19+
"train_split_name": "train.clean.360+train.clean.100+train.other.500+train",
20+
21+
"eval_dataset_name": "ylacombe/libritts_r_filtered+parler-tts/mls_eng",
22+
"eval_metadata_dataset_name": "ylacombe/libritts-r-filtered-descriptions-10k-v5-without-accents+ylacombe/mls-eng-descriptions-v4",
23+
"eval_dataset_config_name": "other+default",
24+
"eval_split_name": "test.other+test",
25+
26+
"target_audio_column_name": "audio",
27+
"description_column_name": "text_description",
28+
"prompt_column_name": "text",
29+
30+
"max_eval_samples": 96,
31+
32+
"max_duration_in_seconds": 30,
33+
"min_duration_in_seconds": 2.0,
34+
"max_text_length": 600,
35+
36+
"group_by_length": true,
37+
38+
"add_audio_samples_to_wandb": true,
39+
"id_column_name": "id",
40+
41+
"preprocessing_num_workers": 8,
42+
43+
"do_train": true,
44+
"num_train_epochs": 4,
45+
"gradient_accumulation_steps": 4,
46+
"gradient_checkpointing": false,
47+
"per_device_train_batch_size": 6,
48+
"learning_rate": 0.00095,
49+
"adam_beta1": 0.9,
50+
"adam_beta2": 0.99,
51+
"weight_decay": 0.01,
52+
53+
"lr_scheduler_type": "constant_with_warmup",
54+
"warmup_steps": 20000,
55+
56+
57+
"logging_steps": 1000,
58+
"freeze_text_encoder": true,
59+
60+
61+
"do_eval": true,
62+
"predict_with_generate": true,
63+
"include_inputs_for_metrics": true,
64+
"evaluation_strategy": "steps",
65+
"eval_steps": 10000,
66+
"save_steps": 10000,
67+
68+
"per_device_eval_batch_size": 4,
69+
70+
"audio_encoder_per_device_batch_size":24,
71+
"dtype": "bfloat16",
72+
"seed": 456,
73+
74+
"dataloader_num_workers":8,
75+
"attn_implementation": "sdpa"
76+
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
{
2+
"model_name_or_path": "./parler-tts-untrained-large/parler-tts-untrained-large",
3+
"save_to_disk": "./tmp_dataset_audio/",
4+
"temporary_save_to_disk": "./audio_code_tmp/",
5+
"wandb_project": "parler-tts-50k-hours",
6+
"wandb_run_name": "Large",
7+
8+
"feature_extractor_name":"ylacombe/dac_44khZ_8kbps",
9+
"description_tokenizer_name":"google/flan-t5-large",
10+
"prompt_tokenizer_name":"google/flan-t5-large",
11+
12+
"report_to": ["wandb"],
13+
"overwrite_output_dir": true,
14+
"output_dir": "./output_dir_training",
15+
16+
"train_dataset_name": "ylacombe/libritts_r_filtered+ylacombe/libritts_r_filtered+ylacombe/libritts_r_filtered+parler-tts/mls_eng",
17+
"train_metadata_dataset_name": "ylacombe/libritts-r-filtered-descriptions-10k-v5-without-accents+ylacombe/libritts-r-filtered-descriptions-10k-v5-without-accents+ylacombe/libritts-r-filtered-descriptions-10k-v5-without-accents+ylacombe/mls-eng-descriptions-v4",
18+
"train_dataset_config_name": "clean+clean+other+default",
19+
"train_split_name": "train.clean.360+train.clean.100+train.other.500+train",
20+
21+
"eval_dataset_name": "ylacombe/libritts_r_filtered+parler-tts/mls_eng",
22+
"eval_metadata_dataset_name": "ylacombe/libritts-r-filtered-descriptions-10k-v5-without-accents+ylacombe/mls-eng-descriptions-v4",
23+
"eval_dataset_config_name": "other+default",
24+
"eval_split_name": "test.other+test",
25+
26+
"target_audio_column_name": "audio",
27+
"description_column_name": "text_description",
28+
"prompt_column_name": "text",
29+
30+
"max_eval_samples": 96,
31+
32+
"max_duration_in_seconds": 30,
33+
"min_duration_in_seconds": 2.0,
34+
"max_text_length": 600,
35+
36+
"group_by_length": true,
37+
38+
"add_audio_samples_to_wandb": true,
39+
"id_column_name": "id",
40+
41+
"preprocessing_num_workers": 8,
42+
43+
"do_train": true,
44+
"num_train_epochs": 4,
45+
"gradient_accumulation_steps": 4,
46+
"gradient_checkpointing": false,
47+
"per_device_train_batch_size": 3,
48+
"learning_rate": 0.0015,
49+
"adam_beta1": 0.9,
50+
"adam_beta2": 0.99,
51+
"weight_decay": 0.01,
52+
53+
"lr_scheduler_type": "constant_with_warmup",
54+
"warmup_steps": 10000,
55+
56+
57+
"logging_steps": 1000,
58+
"freeze_text_encoder": true,
59+
60+
61+
"do_eval": true,
62+
"predict_with_generate": true,
63+
"include_inputs_for_metrics": true,
64+
"evaluation_strategy": "steps",
65+
"eval_steps": 10000,
66+
"save_steps": 10000,
67+
"save_total_limit": 10,
68+
69+
"per_device_eval_batch_size": 6,
70+
71+
"audio_encoder_per_device_batch_size":24,
72+
"dtype": "bfloat16",
73+
"seed": 738,
74+
75+
"dataloader_num_workers":8,
76+
"attn_implementation": "sdpa"
77+
}

0 commit comments

Comments
 (0)