Commit f89f346

🔀 Finalize parallel_wavegan tensorflow support.
1 parent 14584b3 commit f89f346

3 files changed: +217 -6 lines changed
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
# Parallel WaveGAN: A fast waveform generation model based on generative adversarial networks with multi-resolution spectrogram
Based on the script [`train_parallel_wavegan.py`](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/parallel_wavegan/train_parallel_wavegan.py).

## Convert pretrained weight from PyTorch Parallel WaveGAN to TensorFlow Parallel WaveGAN to Accelerate Inference Speed and Deployability

We recommend training with the PyTorch Parallel WaveGAN from [ParallelWaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN), because it is convenient and very stable. After training finishes, you can convert the PyTorch weights to this TensorFlow PWGAN version to accelerate inference speed and improve deployability. You can also take a pretrained weight from [here](https://github.com/kan-bayashi/ParallelWaveGAN#results) and convert it with the [convert_pwgan_from_pytorch_to_tensorflow](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/parallel_wavegan/convert_pwgan_from_pytorch_to_tensorflow.ipynp) notebook. Note that the pretrained PWGAN weights from the PyTorch repo can be used as a vocoder with our text2mel models because both use the same preprocessing procedure (for example, on the LJSpeech dataset). If you want to train PWGAN with TensorFlow instead, follow the instructions below. This path is not fully tested yet: we trained for around 150k steps and everything was fine.

## Training Parallel WaveGAN from scratch with LJSpeech dataset
This example shows how to train Parallel WaveGAN from scratch with TensorFlow 2, using a custom training loop and `tf.function`. The data used for this example is LJSpeech; you can download the dataset from [this link](https://keithito.com/LJ-Speech-Dataset/).

### Step 1: Create Tensorflow based Dataloader (tf.dataset)
Please see the details at [examples/melgan/](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/melgan#step-1-create-tensorflow-based-dataloader-tfdataset).

### Step 2: Training from scratch
After you redefine your dataloader, modify the input arguments `train_dataset` and `valid_dataset` in [`train_parallel_wavegan.py`](https://github.com/tensorspeech/TensorFlowTTS/tree/master/examples/parallel_wavegan/train_parallel_wavegan.py). Here is an example command line for training Parallel WaveGAN from scratch.

First, train the generator for 100K steps with only the STFT loss:

```bash
CUDA_VISIBLE_DEVICES=0 python examples/parallel_wavegan/train_parallel_wavegan.py \
  --train-dir ./dump/train/ \
  --dev-dir ./dump/valid/ \
  --outdir ./examples/parallel_wavegan/exp/train.parallel_wavegan.v1/ \
  --config ./examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml \
  --use-norm 1 \
  --generator_mixed_precision 1 \
  --resume ""
```

Then resume from the 100K-step checkpoint and train the generator and discriminator together:

```bash
CUDA_VISIBLE_DEVICES=0 python examples/parallel_wavegan/train_parallel_wavegan.py \
  --train-dir ./dump/train/ \
  --dev-dir ./dump/valid/ \
  --outdir ./examples/parallel_wavegan/exp/train.parallel_wavegan.v1/ \
  --config ./examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml \
  --use-norm 1 \
  --resume ./examples/parallel_wavegan/exp/train.parallel_wavegan.v1/checkpoints/ckpt-100000
```
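
The two commands above line up with the `discriminator_train_start_steps: 100000` setting in `parallel_wavegan.v1.yaml`: before that step only the STFT loss drives the generator, and afterwards the adversarial term is added. Below is a minimal sketch of that gating. The function name, the exact boundary handling (`<` versus `<=`) and the way the losses are combined are assumptions for illustration, not code taken from `train_parallel_wavegan.py`; only the two constants come from the config.

```python
# Values copied from examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml.
DISCRIMINATOR_TRAIN_START_STEPS = 100_000
LAMBDA_ADV = 4.0


def generator_loss(step: int, stft_loss: float, adv_loss: float) -> float:
    """Hypothetical helper: pick the generator objective for a training step.

    Assumption: the adversarial term is switched on once the global step
    reaches the threshold; the real trainer may treat the boundary differently.
    """
    if step < DISCRIMINATOR_TRAIN_START_STEPS:
        # Warm-up phase: generator is trained with the STFT loss only.
        return stft_loss
    # Adversarial phase: STFT loss plus the weighted adversarial loss.
    return stft_loss + LAMBDA_ADV * adv_loss


if __name__ == "__main__":
    print(generator_loss(50_000, stft_loss=1.0, adv_loss=0.5))   # warm-up phase
    print(generator_loss(150_000, stft_loss=1.0, adv_loss=0.5))  # adversarial phase
```

Warming up the generator first is a common way to keep the discriminator from overpowering an untrained generator early on.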

If you want to train on multiple GPUs, replace `CUDA_VISIBLE_DEVICES=0` with, for example, `CUDA_VISIBLE_DEVICES=0,1,2,3`. You also need to tune the per-GPU `batch_size` (in the config file) yourself to maximize performance. Note that multi-GPU is currently supported for training but not yet for decoding.
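
When tuning the per-GPU `batch_size`, it can help to keep track of the total number of samples processed per step. The sketch below assumes plain data-parallel training, where every visible GPU processes one per-GPU batch per step; that assumption and the helper name `effective_batch_size` are illustrative and not confirmed by the training script.

```python
def effective_batch_size(cuda_visible_devices: str, per_gpu_batch_size: int) -> int:
    """Hypothetical helper: samples per step across all visible GPUs.

    `cuda_visible_devices` is the value of the CUDA_VISIBLE_DEVICES variable,
    e.g. "0" or "0,1,2,3". Assumes one per-GPU batch per device per step.
    """
    device_ids = [d.strip() for d in cuda_visible_devices.split(",") if d.strip()]
    return len(device_ids) * per_gpu_batch_size


if __name__ == "__main__":
    # batch_size: 6 is the value set in parallel_wavegan.v1.yaml by this commit.
    print(effective_batch_size("0", 6))        # 6
    print(effective_batch_size("0,1,2,3", 6))  # 24
```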

To resume training from a checkpoint, pass the `--resume` argument as in the example below:

```bash
--resume ./examples/parallel_wavegan/exp/train.parallel_wavegan.v1/checkpoints/ckpt-100000
```

### Step 3: Decode audio from folder mel-spectrogram
To run inference on a folder of mel-spectrograms (e.g. the valid folder), run the command below:

```bash
CUDA_VISIBLE_DEVICES=0 python examples/parallel_wavegan/decode_parallel_wavegan.py \
  --rootdir ./dump/valid/ \
  --outdir ./prediction/parallel_wavegan.v1/ \
  --checkpoint ./examples/parallel_wavegan/exp/train.parallel_wavegan.v1/checkpoints/generator-400000.h5 \
  --config ./examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml \
  --batch-size 32 \
  --use-norm 1
```
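
The decode script in this commit trims every generated waveform to `mel_length * hop_size` samples before writing it, because mels in a batch are padded to a common length. The sketch below works through that arithmetic on plain Python lists. The helper names are hypothetical, and both `hop_size = 256` (the value the README names for the pretrained model) and the 22050 Hz sampling rate are assumptions here; the real values come from the YAML config.

```python
def trimmed_num_samples(mel_length: int, hop_size: int) -> int:
    """Number of audio samples covered by `mel_length` mel frames."""
    return mel_length * hop_size


def trim_audio(audio: list, mel_length: int, hop_size: int) -> list:
    """Drop the samples that were generated from padded mel frames."""
    return audio[: trimmed_num_samples(mel_length, hop_size)]


def duration_seconds(num_samples: int, sampling_rate: int) -> float:
    """Length of a clip in seconds."""
    return num_samples / sampling_rate


if __name__ == "__main__":
    HOP_SIZE = 256          # assumption, see lead-in
    SAMPLING_RATE = 22050   # assumption, see lead-in

    padded_audio = [0.0] * 30000  # stand-in for one generated waveform
    audio = trim_audio(padded_audio, mel_length=100, hop_size=HOP_SIZE)
    print(len(audio))  # 25600
    print(round(duration_seconds(len(audio), SAMPLING_RATE), 3))
```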

## Finetune Parallel WaveGAN with ljspeech pretrained on other languages
Load the pretrained model, then train on the other language's data with the same procedure as training from scratch. **DO NOT FORGET** to re-run preprocessing on your dataset if needed. The `hop_size` should be 256 if you want to use our pretrained model.

## Reference

1. https://github.com/kan-bayashi/ParallelWaveGAN
2. [Parallel WaveGAN: A fast waveform generation model based on generative adversarial networks with multi-resolution spectrogram](https://arxiv.org/abs/1910.11480)

examples/parallel_wavegan/conf/parallel_wavegan.v1.yaml

Lines changed: 6 additions & 6 deletions
@@ -65,8 +65,8 @@ lambda_adv: 4.0 # Loss balancing coefficient.
 ###########################################################
 # DATA LOADER SETTING #
 ###########################################################
-batch_size: 8 # Batch size.
-batch_max_steps: 16384 # Length of each audio in batch for training. Make sure dividable by hop_size.
+batch_size: 6 # Batch size.
+batch_max_steps: 25600 # Length of each audio in batch for training. Make sure dividable by hop_size.
 batch_max_steps_valid: 81920 # Length of each audio for validation. Make sure dividable by hope_size.
 remove_short_samples: true # Whether to remove samples the length of which are less than batch_max_steps.
 allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory.
@@ -93,10 +93,10 @@ discriminator_optimizer_params:
 ###########################################################
 # INTERVAL SETTING #
 ###########################################################
-discriminator_train_start_steps: 0 # steps begin training discriminator
-train_max_steps: 4000000 # Number of training steps.
-save_interval_steps: 20000 # Interval steps to save checkpoint.
-eval_interval_steps: 5000 # Interval steps to evaluate the network.
+discriminator_train_start_steps: 100000 # steps begin training discriminator
+train_max_steps: 400000 # Number of training steps.
+save_interval_steps: 5000 # Interval steps to save checkpoint.
+eval_interval_steps: 2000 # Interval steps to evaluate the network.
 log_interval_steps: 200 # Interval steps to record the training log.

 ###########################################################
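
The comments in this config require `batch_max_steps` and `batch_max_steps_valid` to be divisible by `hop_size`. The following sketch checks the old and new values from the hunk above. It assumes `hop_size = 256`, the value the README gives for the pretrained model, since the `hop_size` entry itself is not part of this diff; the helper name `frames_per_clip` is hypothetical.

```python
def frames_per_clip(batch_max_steps: int, hop_size: int) -> int:
    """Return how many mel frames cover a clip of `batch_max_steps` samples.

    Raises ValueError when the clip length is not a multiple of hop_size,
    which is the constraint stated in the config comments.
    """
    if batch_max_steps % hop_size != 0:
        raise ValueError(
            f"batch_max_steps={batch_max_steps} is not divisible by hop_size={hop_size}"
        )
    return batch_max_steps // hop_size


if __name__ == "__main__":
    HOP_SIZE = 256  # assumption, see lead-in

    print(frames_per_clip(16384, HOP_SIZE))  # old batch_max_steps -> 64 frames
    print(frames_per_clip(25600, HOP_SIZE))  # new batch_max_steps -> 100 frames
    print(frames_per_clip(81920, HOP_SIZE))  # batch_max_steps_valid -> 320 frames
```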
Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decode trained Parallel WaveGAN from folder."""

import argparse
import logging
import os

import numpy as np
import soundfile as sf
import yaml
from tqdm import tqdm

from tensorflow_tts.configs import ParallelWaveGANGeneratorConfig
from tensorflow_tts.datasets import MelDataset
from tensorflow_tts.models import TFParallelWaveGANGenerator


def main():
    """Run parallel_wavegan decoding from folder."""
    parser = argparse.ArgumentParser(
        description="Generate Audio from melspectrogram with trained parallel_wavegan "
        "(See detail in examples/parallel_wavegan/decode_parallel_wavegan.py)."
    )
    parser.add_argument(
        "--rootdir",
        default=None,
        type=str,
        required=True,
        help="directory including mel-spectrogram feature files.",
    )
    parser.add_argument(
        "--outdir", type=str, required=True, help="directory to save generated speech."
    )
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="checkpoint file to be loaded."
    )
    parser.add_argument(
        "--use-norm", type=int, default=1, help="Use norm or raw melspectrogram."
    )
    parser.add_argument("--batch-size", type=int, default=8, help="batch_size.")
    parser.add_argument(
        "--config",
        default=None,
        type=str,
        required=True,
        help="yaml format configuration file.",
    )
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)",
    )
    args = parser.parse_args()

    # set logger
    if args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check directory existence
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # load config; command-line arguments override YAML keys of the same name
    with open(args.config) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config.update(vars(args))

    # choose which feature files to read from rootdir
    if config["format"] == "npy":
        if "fastspeech" in args.rootdir:
            mel_query = "*-fs-after-feats.npy"
        elif args.use_norm == 1:
            mel_query = "*-norm-feats.npy"
        else:
            mel_query = "*-raw-feats.npy"
        mel_load_fn = np.load
    else:
        raise ValueError("Only npy is supported.")

    # define data-loader
    dataset = MelDataset(
        root_dir=args.rootdir,
        mel_query=mel_query,
        mel_load_fn=mel_load_fn,
    )
    dataset = dataset.create(batch_size=args.batch_size)

    # define model and load checkpoint
    parallel_wavegan = TFParallelWaveGANGenerator(
        config=ParallelWaveGANGeneratorConfig(**config["parallel_wavegan_generator_params"]),
        name="parallel_wavegan_generator",
    )
    parallel_wavegan._build()
    parallel_wavegan.load_weights(args.checkpoint)

    for data in tqdm(dataset, desc="[Decoding]"):
        utt_ids, mels, mel_lengths = data["utt_ids"], data["mels"], data["mel_lengths"]

        # pwgan inference on the batch of mel-spectrograms.
        generated_audios = parallel_wavegan.inference(mels)

        # convert to numpy.
        generated_audios = generated_audios.numpy()  # [B, T]

        # save to outdir, trimming the padding: each mel frame maps to hop_size samples
        for i, audio in enumerate(generated_audios):
            utt_id = utt_ids[i].numpy().decode("utf-8")
            sf.write(
                os.path.join(args.outdir, f"{utt_id}.wav"),
                audio[: mel_lengths[i].numpy() * config["hop_size"]],
                config["sampling_rate"],
                "PCM_16",
            )


if __name__ == "__main__":
    main()
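
One detail of the script worth spelling out is `config.update(vars(args))`: `argparse` stores `--batch-size` under the key `batch_size`, so the command-line value (or its default of 8) replaces any `batch_size` loaded from the YAML file. The sketch below reproduces that merge with the standard library only; `merge_config` is a hypothetical helper and the dictionary stands in for the parsed YAML.

```python
import argparse

# Same option names and defaults as two of the arguments in the decode script.
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--use-norm", type=int, default=1)


def merge_config(yaml_config: dict, argv: list) -> dict:
    """Hypothetical helper: merge CLI arguments over a YAML-style dict."""
    args = parser.parse_args(argv)
    merged = dict(yaml_config)   # copy so the input dict is left untouched
    merged.update(vars(args))    # CLI keys win over YAML keys of the same name
    return merged


if __name__ == "__main__":
    yaml_config = {"batch_size": 6, "hop_size": 256}  # stand-in for the YAML file

    print(merge_config(yaml_config, ["--batch-size", "32"]))
    # Even without the flag, the argparse default overrides the YAML value.
    print(merge_config(yaml_config, []))
```

Because the default also overrides the YAML entry, the batch size used for decoding is always the one given on the command line, not the training `batch_size` from the config.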
