|
| 1 | +import os |
| 2 | +import gc |
| 3 | + |
| 4 | +from trainer import Trainer, TrainerArgs |
| 5 | + |
| 6 | +from TTS.config.shared_configs import BaseDatasetConfig |
| 7 | +from TTS.tts.datasets import load_tts_samples |
| 8 | +from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig |
| 9 | +from TTS.utils.manage import ModelManager |
| 10 | + |
| 11 | + |
def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path, max_audio_length=255995):
    """Fine-tune the XTTS v2 GPT model on a custom Coqui-formatted dataset.

    Downloads the pretrained XTTS v2 / DVAE artifacts if they are not already
    cached under ``output_path/run/training``, builds the trainer configuration,
    runs the fine-tuning loop, and picks a speaker-reference clip from the
    training data (the sample with the longest transcript).

    Args:
        language: Dataset language code passed to ``BaseDatasetConfig``.
        num_epochs: Number of training epochs.
        batch_size: Training (and eval) batch size.
        grad_acumm: Gradient accumulation steps.
        train_csv: Path to the training metadata CSV (Coqui formatter).
        eval_csv: Path to the evaluation metadata CSV.
        output_path: Root directory; checkpoints go to ``output_path/run/training``.
        max_audio_length: Max training wav length in samples (~11.6 s at 22.05 kHz).

    Returns:
        Tuple ``(config_path, checkpoint_path, tokenizer_path, trainer_out_path,
        speaker_ref)`` — paths to the original XTTS config/model/vocab files, the
        trainer's run output directory, and the chosen speaker-reference audio file.
    """
    # Logging parameters
    RUN_NAME = "GPT_XTTS_FT"
    PROJECT_NAME = "XTTS_trainer"
    DASHBOARD_LOGGER = "tensorboard"
    LOGGER_URI = None

    # Set here the path that the checkpoints will be saved. Default: ./run/training/
    OUT_PATH = os.path.join(output_path, "run", "training")

    # Training Parameters
    OPTIMIZER_WD_ONLY_ON_WEIGHTS = True  # for multi-gpu training please make it False
    START_WITH_EVAL = False  # if True it will start with evaluation
    BATCH_SIZE = batch_size  # set here the batch size
    GRAD_ACUMM_STEPS = grad_acumm  # set here the grad accumulation steps

    # Define here the dataset that you want to use for the fine-tuning.
    config_dataset = BaseDatasetConfig(
        formatter="coqui",
        dataset_name="ft_dataset",
        path=os.path.dirname(train_csv),
        meta_file_train=train_csv,
        meta_file_val=eval_csv,
        language=language,
    )

    # Add here the configs of the datasets
    DATASETS_CONFIG_LIST = [config_dataset]

    # Define the path where XTTS v2.0.1 files will be downloaded
    CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
    os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)

    # DVAE files
    DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
    MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"

    # Set the path to the downloaded files
    DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK))
    MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK))

    # download DVAE files if needed
    if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
        print(" > Downloading DVAE files!")
        ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)

    # Download XTTS v2.0 checkpoint if needed
    TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
    XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth"
    XTTS_CONFIG_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json"

    # XTTS transfer learning parameters: provide the paths of the XTTS model checkpoint that you want to fine-tune.
    TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK))  # vocab.json file
    XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK))  # model.pth file
    XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CONFIG_LINK))  # config.json file

    # download XTTS v2.0 files if needed
    # BUGFIX: also check XTTS_CONFIG_FILE — it is in the download list and is
    # returned to the caller, so a missing config.json must trigger a download.
    if (
        not os.path.isfile(TOKENIZER_FILE)
        or not os.path.isfile(XTTS_CHECKPOINT)
        or not os.path.isfile(XTTS_CONFIG_FILE)
    ):
        print(" > Downloading XTTS v2.0 files!")
        ModelManager._download_model_files(
            [TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK, XTTS_CONFIG_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
        )

    # init args and config
    model_args = GPTArgs(
        max_conditioning_length=132300,  # 6 secs
        min_conditioning_length=66150,  # 3 secs
        debug_loading_failures=False,
        max_wav_length=max_audio_length,  # ~11.6 seconds
        max_text_length=200,
        mel_norm_file=MEL_NORM_FILE,
        dvae_checkpoint=DVAE_CHECKPOINT,
        xtts_checkpoint=XTTS_CHECKPOINT,  # checkpoint path of the model that you want to fine-tune
        tokenizer_file=TOKENIZER_FILE,
        gpt_num_audio_tokens=1026,
        gpt_start_audio_token=1024,
        gpt_stop_audio_token=1025,
        gpt_use_masking_gt_prompt_approach=True,
        gpt_use_perceiver_resampler=True,
    )
    # define audio config
    audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
    # training parameters config
    config = GPTTrainerConfig(
        epochs=num_epochs,
        output_path=OUT_PATH,
        model_args=model_args,
        run_name=RUN_NAME,
        project_name=PROJECT_NAME,
        run_description="""
            GPT XTTS training
            """,
        dashboard_logger=DASHBOARD_LOGGER,
        logger_uri=LOGGER_URI,
        audio=audio_config,
        batch_size=BATCH_SIZE,
        batch_group_size=48,
        eval_batch_size=BATCH_SIZE,
        num_loader_workers=8,
        eval_split_max_size=256,
        print_step=50,
        plot_step=100,
        log_model_step=100,
        save_step=1000,
        save_n_checkpoints=1,
        save_checkpoints=True,
        # target_loss="loss",
        print_eval=False,
        # Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters.
        optimizer="AdamW",
        optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
        optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
        lr=5e-06,  # learning rate
        lr_scheduler="MultiStepLR",
        # it was adjusted accordingly for the new step scheme
        lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
        test_sentences=[],
    )

    # init the model from config
    model = GPTTrainer.init_from_config(config)

    # load training samples
    train_samples, eval_samples = load_tts_samples(
        DATASETS_CONFIG_LIST,
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )

    # init the trainer and 🚀
    trainer = Trainer(
        TrainerArgs(
            restore_path=None,  # xtts checkpoint is restored via xtts_checkpoint key so no need to restore it using Trainer restore_path parameter
            skip_train_epoch=False,
            start_with_eval=START_WITH_EVAL,
            grad_accum_steps=GRAD_ACUMM_STEPS,
        ),
        config,
        output_path=OUT_PATH,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )
    trainer.fit()

    # get the longest-text audio file to use as speaker reference
    # (longer transcripts tend to give longer, more usable reference audio)
    samples_len = [len(item["text"].split(" ")) for item in train_samples]
    longest_text_idx = samples_len.index(max(samples_len))
    speaker_ref = train_samples[longest_text_idx]["audio_file"]

    trainer_out_path = trainer.output_path

    # deallocate VRAM and RAM
    del model, trainer, train_samples, eval_samples
    gc.collect()

    return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref
0 commit comments