Commit 1e4a650

Merge pull request #144 from lmxue/valle_resume
Fix bug for VALLE resume
2 parents 6892d03 + bd3ce83 commit 1e4a650

6 files changed: +162 -38 lines changed
config/valle.json

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@
         // "scaling_xformers": false, // Apply Reworked Conformer scaling on Transformers
     },
     "train": {
+        "use_dynamic_batchsize": false, // If use dynamic batch size
         "ddp": false,
         "train_stage": 1, // 0: train all modules, For VALL_E, support 1: AR Decoder 2: NAR Decoder(s)
         "max_epoch": 20,

egs/tts/VALLE/README.md

Lines changed: 78 additions & 11 deletions
@@ -17,7 +17,7 @@ There are four stages in total:
 ## 1. Data Preparation
 
 ### Dataset Download
-You can use the commonly used TTS dataset to train VALL-E model, e.g., LibriTTS, etc. We strongly recommend you use LibriTTS to train VALL-E model for the first time. How to download dataset is detailed [here](../../datasets/README.md).
+You can use the commonly used TTS dataset to train the VALL-E model, e.g., LibriTTS, etc. We strongly recommend you use LibriTTS to train the VALL-E model for the first time. How to download the dataset is detailed [here](../../datasets/README.md).
 
 ### Configuration
 
@@ -51,7 +51,7 @@ Specify the `processed_dir` and the `log_dir` and for saving the processed data
 
 ### Run
 
-Run the `run.sh` as the preproces stage (set `--stage 1`):
+Run the `run.sh` as the preprocess stage (set `--stage 1`):
 
 ```bash
 sh egs/tts/VALLE/run.sh --stage 1
@@ -64,22 +64,22 @@ sh egs/tts/VALLE/run.sh --stage 1
 
 ### Configuration
 
-We provide the default hyparameters in the `exp_config.json`. They can work on single NVIDIA-24g GPU. You can adjust them based on your GPU machines.
+We provide the default hyperparameters in the `exp_config.json`. They can work on a single NVIDIA-24g GPU. You can adjust them based on your GPU machines.
 
-```
+```json
 "train": {
     "batch_size": 4,
 }
 ```
 
-### Run
+### Train From Scratch
 
-Run the `run.sh` as the training stage (set `--stage 2`). Specify a experimental name to run the following command. The tensorboard logs and checkpoints will be saved in `Amphion/ckpts/tts/[YourExptName]`.
+Run the `run.sh` as the training stage (set `--stage 2`). Specify an experimental name to run the following command. The tensorboard logs and checkpoints will be saved in `Amphion/ckpts/tts/[YourExptName]`.
 
-Specifically, VALL-E need to train a autoregressive (AR) model and then a non-autoregressive (NAR) model. So, you can set `--model_train_stage 1` to train AR model, and set `--model_train_stage 2` to train NAR model, where `--ar_model_ckpt_dir` should be set as the ckeckpoint path to the trained AR model.
+Specifically, VALL-E needs to train an autoregressive (AR) model and then a non-autoregressive (NAR) model. So, you can set `--model_train_stage 1` to train AR model, and set `--model_train_stage 2` to train NAR model, where `--ar_model_ckpt_dir` should be set as the checkpoint path to the trained AR model.
 
 
-Train a AR moel, just run:
+Train an AR model, just run:
 
 ```bash
 sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 1 --name [YourExptName]
@@ -89,7 +89,74 @@ Train a NAR model, just run:
 ```bash
 sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 2 --ar_model_ckpt_dir [ARModelPath] --name [YourExptName]
 ```
-<!-- > **NOTE:** To train a NAR model, `--checkpoint_path` should be set as the ckeckpoint path to the trained AR model. -->
+<!-- > **NOTE:** To train a NAR model, `--checkpoint_path` should be set as the checkpoint path to the trained AR model. -->
+
+
+### Train From Existing Source
+
+We support training from existing sources for various purposes. You can resume training the model from a checkpoint or fine-tune a model from another checkpoint.
+
+By setting `--resume true`, the training will resume from the **latest checkpoint** from the current `[YourExptName]` by default. For example, if you want to resume training from the latest checkpoint in `Amphion/ckpts/tts/[YourExptName]/checkpoint`,
+
+Train an AR model, just run:
+
+```bash
+sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 1 --name [YourExptName] \
+    --resume true
+```
+
+Train a NAR model, just run:
+```bash
+sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 2 --ar_model_ckpt_dir [ARModelPath] --name [YourExptName] \
+    --resume true
+```
+
+
+
+You can also choose a **specific checkpoint** for retraining by `--resume_from_ckpt_path` argument. For example, if you want to resume training from the checkpoint `Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificCheckpoint]`,
+
+Train an AR model, just run:
+
+```bash
+sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 1 --name [YourExptName] \
+    --resume true \
+    --resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificARCheckpoint]"
+```
+
+Train a NAR model, just run:
+```bash
+sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 2 --ar_model_ckpt_dir [ARModelPath] --name [YourExptName] \
+    --resume true \
+    --resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificNARCheckpoint]"
+```
+
+
+If you want to **fine-tune from another checkpoint**, just use `--resume_type` and set it to `"finetune"`. For example, If you want to fine-tune the model from the checkpoint `Amphion/ckpts/tts/[AnotherExperiment]/checkpoint/[SpecificCheckpoint]`,
+
+
+Train an AR model, just run:
+
+```bash
+sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 1 --name [YourExptName] \
+    --resume true \
+    --resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificARCheckpoint]" \
+    --resume_type "finetune"
+```
+
+Train a NAR model, just run:
+```bash
+sh egs/tts/VALLE/run.sh --stage 2 --model_train_stage 2 --ar_model_ckpt_dir [ARModelPath] --name [YourExptName] \
+    --resume true \
+    --resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificNARCheckpoint]" \
+    --resume_type "finetune"
+```
+
+> **NOTE:** The `--resume_type` is set as `"resume"` in default. It's not necessary to specify it when resuming training.
+>
+> The difference between `"resume"` and `"finetune"` is that the `"finetune"` will **only** load the pretrained model weights from the checkpoint, while the `"resume"` will load all the training states (including optimizer, scheduler, etc.) from the checkpoint.
+
+
+
 
 > **NOTE:** The `CUDA_VISIBLE_DEVICES` is set as `"0"` in default. You can change it when running `run.sh` by specifying such as `--gpu "0,1,2,3"`.
 
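
Reviewer sketch (not part of the commit): the README note added in this hunk distinguishes `"resume"` from `"finetune"` by what gets loaded from the checkpoint. The snippet below illustrates that distinction with plain dictionaries standing in for the model, optimizer, and scheduler. The function name `load_checkpoint`, the dictionary layout, and the `step`/`epoch` keys are assumptions made for illustration; they are not Amphion's actual loading code.

```python
from copy import deepcopy


def load_checkpoint(trainer_state, checkpoint, resume_type="resume"):
    """Return a new trainer state restored from `checkpoint`.

    Both arguments are plain dicts standing in for real model, optimizer,
    and scheduler objects (an assumption made for this sketch).
    """
    if resume_type not in ("resume", "finetune"):
        raise ValueError(f"Unsupported resume_type: {resume_type}")

    restored = deepcopy(trainer_state)
    # Both modes load the model weights.
    restored["model"] = deepcopy(checkpoint["model"])

    if resume_type == "resume":
        # Only "resume" also restores the training states.
        # `step` and `epoch` are hypothetical stand-ins for the README's "etc.".
        for key in ("optimizer", "scheduler", "step", "epoch"):
            restored[key] = deepcopy(checkpoint[key])
    return restored


fresh = {"model": {"w": 0.0}, "optimizer": {"lr": 1e-3},
         "scheduler": {"last": 0}, "step": 0, "epoch": 0}
saved = {"model": {"w": 0.7}, "optimizer": {"lr": 5e-4},
         "scheduler": {"last": 40}, "step": 4000, "epoch": 3}

resumed = load_checkpoint(fresh, saved, "resume")
finetuned = load_checkpoint(fresh, saved, "finetune")
print(resumed["step"], finetuned["step"])
```

With `"finetune"`, the weights come from the checkpoint while the optimizer, scheduler, and counters keep their fresh values, which matches the behaviour the note describes.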
@@ -127,8 +194,8 @@ sh egs/tts/VALLE/run.sh --stage 3 --gpu "0" \
 ```
 
 We have released pre-trained VALL-E models, so you can download the pre-trained model and then generate speech following the above inference instruction. Specifically,
-1. The pre-trained VALL-E trained on [LibriTTS](https://github.com/open-mmlab/Amphion/tree/main/egs/datasets#libritts) can be download [here](https://huggingface.co/amphion/valle-libritts).
-2. The pre-trained VALL-E trained on a part of [Libri-light](https://ai.meta.com/tools/libri-light/) (about 6k hours) can be download [here](https://huggingface.co/amphion/valle_librilight_6k).
+1. The pre-trained VALL-E trained on [LibriTTS](https://github.com/open-mmlab/Amphion/tree/main/egs/datasets#libritts) can be downloaded [here](https://huggingface.co/amphion/valle-libritts).
+2. The pre-trained VALL-E trained on the part of [Libri-light](https://ai.meta.com/tools/libri-light/) (about 6k hours) can be downloaded [here](https://huggingface.co/amphion/valle_librilight_6k).
 
 ```bibtex
 @article{wang2023neural,

egs/tts/VALLE/run.sh

Lines changed: 40 additions & 8 deletions
@@ -17,7 +17,7 @@ python setup.py build_ext --inplace
 cd $work_dir
 
 ######## Parse the Given Parameters from the Commond ###########
-options=$(getopt -o c:n:s --long gpu:,config:,infer_expt_dir:,ar_model_ckpt_dir:,infer_output_dir:,infer_mode:,infer_test_list_file:,infer_text:,infer_text_prompt:,infer_audio_prompt:,model_train_stage:,name:,stage: -- "$@")
+options=$(getopt -o c:n:s --long gpu:,config:,infer_expt_dir:,ar_model_ckpt_dir:,infer_output_dir:,infer_mode:,infer_test_list_file:,infer_text:,infer_text_prompt:,infer_audio_prompt:,model_train_stage:,name:,stage:,resume:,resume_from_ckpt_path:,resume_type: -- "$@")
 eval set -- "$options"
 
 while true; do
@@ -52,6 +52,13 @@ while true; do
     # [Only for Inference] The inference audio prompt. It is only used when the inference model is "single".
     --infer_audio_prompt) shift; infer_audio_prompt=$1 ; shift ;;
 
+    # [Only for Training] Resume configuration
+    --resume) shift; resume=$1 ; shift ;;
+    # [Only for Training] The specific checkpoint path that you want to resume from.
+    --resume_from_ckpt_path) shift; resume_from_ckpt_path=$1 ; shift ;;
+    # [Only for Training] `resume` for loading all the things (including model weights, optimizer, scheduler, and random states). `finetune` for loading only the model weights.
+    --resume_type) shift; resume_type=$1 ; shift ;;
+
     --) shift ; break ;;
     *) echo "Invalid option: $1" exit 1 ;;
   esac
@@ -98,13 +105,38 @@ if [ $running_stage -eq 2 ]; then
 
     echo "Exprimental Name: $exp_name"
 
-    CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 29510 \
-        "${work_dir}"/bins/tts/train.py \
-        --config $exp_config \
-        --exp_name $exp_name \
-        --log_level debug \
-        --train_stage $model_train_stage \
-        --checkpoint_path $ar_model_ckpt_dir
+    # Add default value
+    if [ -z "$resume_from_ckpt_path" ]; then
+        resume_from_ckpt_path=""
+    fi
+
+    if [ -z "$resume_type" ]; then
+        resume_type="resume"
+    fi
+
+
+    if [ "$resume" = true ]; then
+        echo "Resume from the existing experiment..."
+        CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 29510 \
+            "${work_dir}"/bins/tts/train.py \
+            --config $exp_config \
+            --exp_name $exp_name \
+            --log_level debug \
+            --train_stage $model_train_stage \
+            --ar_model_ckpt_dir $ar_model_ckpt_dir \
+            --resume \
+            --checkpoint_path "$resume_from_ckpt_path" \
+            --resume_type "$resume_type"
+    else
+        echo "Start a new experiment..."
+        CUDA_VISIBLE_DEVICES=$gpu accelerate launch --main_process_port 29510 \
+            "${work_dir}"/bins/tts/train.py \
+            --config $exp_config \
+            --exp_name $exp_name \
+            --log_level debug \
+            --train_stage $model_train_stage \
+            --ar_model_ckpt_dir $ar_model_ckpt_dir
+    fi
 fi
 
 
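
Reviewer sketch (not part of the commit): the stage-2 branch added to `run.sh` fills in defaults and then chooses between two `train.py` invocations. The Python function below mirrors that decision so the mapping from shell options to `train.py` arguments is easier to follow. `build_train_args` is a hypothetical helper, and the sketch assumes every value is a non-empty string except the three resume options; it does not model shell quoting or the `accelerate launch` prefix.

```python
def build_train_args(exp_config, exp_name, model_train_stage, ar_model_ckpt_dir,
                     resume="", resume_from_ckpt_path="", resume_type=""):
    """Return the train.py argument list that the run.sh branch would pass."""
    # Same default as run.sh: an empty resume_type falls back to "resume".
    if not resume_type:
        resume_type = "resume"

    args = [
        "--config", exp_config,
        "--exp_name", exp_name,
        "--log_level", "debug",
        "--train_stage", str(model_train_stage),
        "--ar_model_ckpt_dir", ar_model_ckpt_dir,
    ]
    # run.sh compares the option against the literal string `true`,
    # so any other spelling starts a new experiment.
    if resume == "true":
        args += [
            "--resume",
            "--checkpoint_path", resume_from_ckpt_path,
            "--resume_type", resume_type,
        ]
    return args


new_run = build_train_args("exp_config.json", "my_expt", 1, "ckpts/ar")
resumed_run = build_train_args("exp_config.json", "my_expt", 1, "ckpts/ar", resume="true")
print(resumed_run[-2:])
```

Note that the shell option `--resume_from_ckpt_path` is forwarded to `train.py` as `--checkpoint_path`, while the AR checkpoint now travels through the separate `--ar_model_ckpt_dir` argument.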

egs/tts/VITS/README.md

Lines changed: 12 additions & 8 deletions
@@ -3,7 +3,7 @@
 [![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Spaces-yellow)](https://huggingface.co/spaces/amphion/Text-to-Speech)
 [![openxlab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Amphion/Text-to-Speech)
 
-In this recipe, we will show how to train VITS using Amphion's infrastructure. [VITS](https://arxiv.org/abs/2106.06103) is an end-to-end TTS architecture that utilizes conditional variational autoencoder with adversarial learning.
+In this recipe, we will show how to train VITS using Amphion's infrastructure. [VITS](https://arxiv.org/abs/2106.06103) is an end-to-end TTS architecture that utilizes a conditional variational autoencoder with adversarial learning.
 
 There are four stages in total:
 
@@ -20,7 +20,7 @@ There are four stages in total:
 ## 1. Data Preparation
 
 ### Dataset Download
-You can use the commonly used TTS dataset to train TTS model, e.g., LJSpeech, VCTK, Hi-Fi TTS, LibriTTS, etc. We strongly recommend using LJSpeech to train single-speaker TTS model for the first time. While for training multi-speaker TTS model for the first time, we would recommend using Hi-Fi TTS. The process of downloading dataset has been detailed [here](../../datasets/README.md).
+You can use the commonly used TTS dataset to train the TTS model, e.g., LJSpeech, VCTK, Hi-Fi TTS, LibriTTS, etc. We strongly recommend using LJSpeech to train the single-speaker TTS model for the first time. While training the multi-speaker TTS model for the first time, we recommend using Hi-Fi TTS. The process of downloading the dataset has been detailed [here](../../datasets/README.md).
 
 ### Configuration
 
@@ -75,7 +75,7 @@ sh egs/tts/VITS/run.sh --stage 1
 
 ### Configuration
 
-We provide the default hyparameters in the `exp_config.json`. They can work on a single NVIDIA-24g GPU. You can adjust them based on your GPU machines.
+We provide the default hyperparameters in the `exp_config.json`. They can work on a single NVIDIA-24g GPU. You can adjust them based on your GPU machines.
 For training the multi-speaker TTS model, specify the `n_speakers` value to be greater (used for new speaker fine-tuning) than or equal to the number of speakers in your dataset(s) and set `multi_speaker_training` to `true`.
 
 ```json
@@ -98,9 +98,9 @@ sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName]
 
 ### Train From Existing Source
 
-We support training from existing source for various purposes. You can resume training the model from a checkpoint or fine-tune a model from another checkpoint.
+We support training from existing sources for various purposes. You can resume training the model from a checkpoint or fine-tune a model from another checkpoint.
 
-Setting `--resume true`, the training will resume from the **latest checkpoint** from the current `[YourExptName]` by default. For example, if you want to resume training from the latest checkpoint in `Amphion/ckpts/tts/[YourExptName]/checkpoint`, run:
+By setting `--resume true`, the training will resume from the **latest checkpoint** from the current `[YourExptName]` by default. For example, if you want to resume training from the latest checkpoint in `Amphion/ckpts/tts/[YourExptName]/checkpoint`, run:
 
 ```bash
 sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName] \
@@ -111,16 +111,16 @@ You can also choose a **specific checkpoint** for retraining by `--resume_from_c
 
 ```bash
 sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName] \
-    --resume true
-    --resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificCheckpoint]" \
+    --resume true \
+    --resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificCheckpoint]"
 ```
 
 If you want to **fine-tune from another checkpoint**, just use `--resume_type` and set it to `"finetune"`. For example, If you want to fine-tune the model from the checkpoint `Amphion/ckpts/tts/[AnotherExperiment]/checkpoint/[SpecificCheckpoint]`, run:
 
 
 ```bash
 sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName] \
-    --resume true
+    --resume true \
     --resume_from_ckpt_path "Amphion/ckpts/tts/[YourExptName]/checkpoint/[SpecificCheckpoint]" \
     --resume_type "finetune"
 ```
@@ -206,6 +206,10 @@ sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \
     --infer_testing_set "test"
 ```
 
+
+We released a pre-trained Amphion VITS model trained on LJSpeech. So, you can download the pre-trained model [here](https://huggingface.co/amphion/vits-ljspeech) and generate speech following the above inference instructions. Meanwhile, the pre-trained multi-speaker VITS model trained on Hi-Fi TTS will be released soon. Stay tuned.
+
+
 ```bibtex
 @inproceedings{kim2021conditional,
   title={Conditional variational autoencoder with adversarial learning for end-to-end text-to-speech},

models/tts/base/tts_trainer.py

Lines changed: 17 additions & 3 deletions
@@ -163,14 +163,27 @@ def _check_resume(self):
         if self.args.resume or (
             self.cfg.model_type == "VALLE" and self.args.train_stage == 2
         ):
+            checkpoint_dir = self.checkpoint_dir
             if self.cfg.model_type == "VALLE" and self.args.train_stage == 2:
-                self.args.resume_type = "finetune"
+                ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
+                if (
+                    self.args.checkpoint_path is None or len(ls) == 0
+                ): # Train stage 2 from scratch using the checkpoint of stage 1
+                    assert (
+                        self.args.ar_model_ckpt_dir is not None
+                    ), "Error: ar_model_ckpt_dir should be set to train nar model."
+                    self.args.resume_type = "finetune"
+                    checkpoint_dir = self.args.ar_model_ckpt_dir
+                    self.logger.info(
+                        f"Training NAR model at stage 2 using the checkpoint of AR model at stage 1."
+                    )
 
-            self.logger.info("Resuming from checkpoint...")
+            self.logger.info(f"Resuming from checkpoint: {checkpoint_dir}")
             start = time.monotonic_ns()
             self.ckpt_path = self._load_model(
-                self.checkpoint_dir, self.args.checkpoint_path, self.args.resume_type
+                checkpoint_dir, self.args.checkpoint_path, self.args.resume_type
             )
+            self.logger.info(f"Checkpoint path: {self.ckpt_path}")
             end = time.monotonic_ns()
             self.logger.info(
                 f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
@@ -700,6 +713,7 @@ def _save_phone_symbols_file_to_exp_path(self):
             self.exp_dir, self.cfg.preprocess.symbols_dict
         )
         shutil.copy(phone_symbols_file, phone_symbols_file_to_exp_path)
+        os.chmod(phone_symbols_file_to_exp_path, 0o666)
         print(
             "phone symbols been dumped to {}".format(
                 os.path.join(self.exp_dir, self.cfg.preprocess.symbols_dict)
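
Reviewer sketch (not part of the commit): the `_check_resume` change in this file decides which directory the NAR stage loads from. The standalone function below isolates that inner decision so it can be exercised without a trainer object. `select_checkpoint_source` and its parameters are hypothetical names; the sketch covers only the branch shown in the diff and assumes the caller has already passed the outer `resume or (VALLE and stage 2)` check. It raises `ValueError` where the commit uses `assert`.

```python
from pathlib import Path
import tempfile


def select_checkpoint_source(model_type, train_stage, checkpoint_dir,
                             checkpoint_path=None, ar_model_ckpt_dir=None,
                             resume_type="resume"):
    """Return (directory to load from, resume type), following the diff's branches."""
    if model_type == "VALLE" and train_stage == 2:
        existing = list(Path(checkpoint_dir).glob("*"))
        if checkpoint_path is None or len(existing) == 0:
            # No NAR checkpoint to continue from: start stage 2 from the AR weights.
            if ar_model_ckpt_dir is None:
                raise ValueError("ar_model_ckpt_dir should be set to train the NAR model.")
            return ar_model_ckpt_dir, "finetune"
    # Otherwise keep loading from the experiment's own checkpoint directory.
    return checkpoint_dir, resume_type


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp) / "checkpoint"
    ckpt_dir.mkdir()

    # First NAR run: the experiment has no checkpoints yet.
    first_run = select_checkpoint_source(
        "VALLE", 2, ckpt_dir, ar_model_ckpt_dir="ckpts/ar")

    # Later NAR run: a checkpoint exists and a path was given.
    (ckpt_dir / "epoch-0001").mkdir()
    later_run = select_checkpoint_source(
        "VALLE", 2, ckpt_dir,
        checkpoint_path=str(ckpt_dir / "epoch-0001"),
        ar_model_ckpt_dir="ckpts/ar")

print(first_run[1], later_run[1])
```

Before this commit, stage 2 always forced `resume_type = "finetune"`, so a NAR run could not continue with its own optimizer and scheduler states; the new branch forces it only when there is nothing to resume from.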

models/tts/valle/valle_trainer.py

Lines changed: 14 additions & 8 deletions
@@ -262,14 +262,6 @@ def _valid_step(self, batch):
 
         return total_loss, valid_losses, valid_stats
 
-    def add_arguments(parser: argparse.ArgumentParser):
-        parser.add_argument(
-            "--train_stage",
-            type=int,
-            default="1",
-            help="0: train all modules, 1: AR Decoder, 2: NAR Decoder",
-        )
-
     def _build_dataloader(self):
         if not self.cfg.train.use_dynamic_batchsize:
             return super()._build_dataloader()
@@ -359,3 +351,17 @@ def _accelerator_prepare(self):
                 self.scheduler[key] = self.accelerator.prepare(self.scheduler[key])
         else:
             self.scheduler = self.accelerator.prepare(self.scheduler)
+
+    def add_arguments(parser: argparse.ArgumentParser):
+        parser.add_argument(
+            "--train_stage",
+            type=int,
+            default="1",
+            help="0: train all modules, 1: AR Decoder, 2: NAR Decoder",
+        )
+        parser.add_argument(
+            "--ar_model_ckpt_dir",
+            type=str,
+            default=None,
+            help="Checkpoint for ar model ckeckpoint in the first training stage.",
+        )