Commit d68da28

Export MeloTTS to ONNX (k2-fsa#1129)
1 parent dd5abd8 commit d68da28

4 files changed: +573, -0 lines changed
Lines changed: 101 additions & 0 deletions (GitHub Actions workflow)
name: export-melo-tts-to-onnx

on:
  push:
    branches:
      - export-melo-tts-onnx
  workflow_dispatch:

concurrency:
  group: export-melo-tts-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-melo-tts-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export melo-tts
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run
        shell: bash
        run: |
          cd scripts/melo-tts
          ./run.sh

      - uses: actions/upload-artifact@v4
        with:
          name: test.wav
          path: scripts/melo-tts/test.wav

      - name: Publish to huggingface (aishell)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "[email protected]"
            git config --global user.name "Fangjun Kuang"

            rm -rf huggingface
            export GIT_LFS_SKIP_SMUDGE=1
            export GIT_CLONE_PROTECTION_ACTIVE=false

            git clone https://huggingface.co/csukuangfj/vits-melo-tts-zh_en huggingface
            cd huggingface
            git fetch
            git pull
            echo "pwd: $PWD"
            ls -lh ../scripts/melo-tts

            cp -v ../scripts/melo-tts/*.onnx .
            cp -v ../scripts/melo-tts/lexicon.txt .
            cp -v ../scripts/melo-tts/tokens.txt .

            curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/date.fst
            curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/number.fst
            curl -SL -O https://huggingface.co/csukuangfj/icefall-tts-aishell3-vits-low-2024-04-06/resolve/main/data/phone.fst
            curl -SL -O https://github.com/csukuangfj/cppjieba/releases/download/sherpa-onnx-2024-04-19/dict.tar.bz2
            tar xvf dict.tar.bz2
            rm dict.tar.bz2

            git lfs track "*.onnx"
            git add .

            git commit -m "add models"
            git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/vits-melo-tts-zh_en main || true

            cd ..

            rm -rf huggingface/.git*
            dst=vits-melo-tts-zh_en

            mv huggingface $dst

            tar cjvf $dst.tar.bz2 $dst
            rm -rf $dst

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: tts-models
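
Not part of this commit: a minimal Python sketch of how the tarball published by the Release step above could be fetched afterwards. It assumes GitHub's standard releases/download/<tag>/<asset> URL layout and the asset name vits-melo-tts-zh_en.tar.bz2 produced by the tar command in the workflow; only the Python standard library is used.

import tarfile
import urllib.request

# The Release step uploads ./*.tar.bz2 to the `tts-models` tag of k2-fsa/sherpa-onnx;
# this URL assumes the usual GitHub release-asset layout.
url = (
    "https://github.com/k2-fsa/sherpa-onnx/releases/download/"
    "tts-models/vits-melo-tts-zh_en.tar.bz2"
)
urllib.request.urlretrieve(url, "vits-melo-tts-zh_en.tar.bz2")

# Extract model.onnx, lexicon.txt, tokens.txt, the *.fst files and the jieba
# dictionary into ./vits-melo-tts-zh_en/
with tarfile.open("vits-melo-tts-zh_en.tar.bz2", "r:bz2") as f:
    f.extractall(".")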

scripts/melo-tts/export-onnx.py

Lines changed: 256 additions & 0 deletions
#!/usr/bin/env python3
from typing import Any, Dict

import onnx
import torch
from melo.api import TTS
from melo.text import language_id_map, language_tone_start_map
from melo.text.chinese import pinyin_to_symbol_map
from pypinyin import Style, lazy_pinyin, phrases_dict, pinyin_dict

for k, v in pinyin_to_symbol_map.items():
    pinyin_to_symbol_map[k] = v.split()


def get_initial_final_tone(word: str):
    initials = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.INITIALS)
    finals = lazy_pinyin(word, neutral_tone_with_five=True, style=Style.FINALS_TONE3)

    ans_phone = []
    ans_tone = []

    for c, v in zip(initials, finals):
        raw_pinyin = c + v
        v_without_tone = v[:-1]
        try:
            tone = v[-1]
        except:
            print("skip", word, initials, finals)
            return [], []

        pinyin = c + v_without_tone
        assert tone in "12345"

        if c:
            v_rep_map = {
                "uei": "ui",
                "iou": "iu",
                "uen": "un",
            }
            if v_without_tone in v_rep_map.keys():
                pinyin = c + v_rep_map[v_without_tone]
        else:
            pinyin_rep_map = {
                "ing": "ying",
                "i": "yi",
                "in": "yin",
                "u": "wu",
            }
            if pinyin in pinyin_rep_map.keys():
                pinyin = pinyin_rep_map[pinyin]
            else:
                single_rep_map = {
                    "v": "yu",
                    "e": "e",
                    "i": "y",
                    "u": "w",
                }
                if pinyin[0] in single_rep_map.keys():
                    pinyin = single_rep_map[pinyin[0]] + pinyin[1:]
        #  print(word, initials, finals, pinyin)

        if pinyin not in pinyin_to_symbol_map:
            print("skip", pinyin, word, c, v, raw_pinyin)
            continue
        phone = pinyin_to_symbol_map[pinyin]
        ans_phone += phone
        ans_tone += [tone] * len(phone)

    return ans_phone, ans_tone


def generate_tokens(symbol_list):
    with open("tokens.txt", "w", encoding="utf-8") as f:
        for i, s in enumerate(symbol_list):
            f.write(f"{s} {i}\n")


def generate_lexicon():
    word_dict = pinyin_dict.pinyin_dict
    phrases = phrases_dict.phrases_dict
    with open("lexicon.txt", "w", encoding="utf-8") as f:
        for key in word_dict:
            if not (0x4E00 <= key <= 0x9FA5):
                continue
            w = chr(key)
            phone, tone = get_initial_final_tone(w)
            if not phone:
                continue
            phone = " ".join(phone)
            tone = " ".join(tone)
            f.write(f"{w} {phone} {tone}\n")

        for w in phrases:
            phone, tone = get_initial_final_tone(w)
            if not phone:
                continue
            assert len(phone) == len(tone), (len(phone), len(tone), phone, tone)
            phone = " ".join(phone)
            tone = " ".join(tone)
            f.write(f"{w} {phone} {tone}\n")


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    while len(model.metadata_props):
        model.metadata_props.pop()

    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


class ModelWrapper(torch.nn.Module):
    def __init__(self, model: "SynthesizerTrn"):
        super().__init__()
        self.model = model

    def forward(
        self,
        x,
        x_lengths,
        tones,
        lang_id,
        bert,
        ja_bert,
        sid,
        noise_scale,
        length_scale,
        noise_scale_w,
        max_len=None,
    ):
        """
        Args:
          x: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
          tones: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
          lang_id: A 1-D array of dtype np.int64. Its shape is (token_numbers,)
          sid: an integer
        """
        return self.model.infer(
            x=x,
            x_lengths=x_lengths,
            sid=sid,
            tone=tones,
            language=lang_id,
            bert=bert,
            ja_bert=ja_bert,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
        )[0]


def main():
    generate_lexicon()

    language = "ZH"
    model = TTS(language=language, device="cpu")

    generate_tokens(model.hps["symbols"])

    torch_model = ModelWrapper(model.model)

    opset_version = 13
    x = torch.randint(low=0, high=10, size=(60,), dtype=torch.int64)
    print(x.shape)
    x_lengths = torch.tensor([x.size(0)], dtype=torch.int64)
    sid = torch.tensor([1], dtype=torch.int64)
    tones = torch.zeros_like(x)
    lang_id = torch.ones_like(x)
    noise_scale = torch.tensor([1.0], dtype=torch.float32)
    length_scale = torch.tensor([1.0], dtype=torch.float32)
    noise_scale_w = torch.tensor([1.0], dtype=torch.float32)

    bert = torch.zeros(1024, x.shape[0], dtype=torch.float32)
    ja_bert = torch.zeros(768, x.shape[0], dtype=torch.float32)

    x = x.unsqueeze(0)
    tones = tones.unsqueeze(0)
    lang_id = lang_id.unsqueeze(0)
    bert = bert.unsqueeze(0)
    ja_bert = ja_bert.unsqueeze(0)

    filename = "model.onnx"

    torch.onnx.export(
        torch_model,
        (
            x,
            x_lengths,
            tones,
            lang_id,
            bert,
            ja_bert,
            sid,
            noise_scale,
            length_scale,
            noise_scale_w,
        ),
        filename,
        opset_version=opset_version,
        input_names=[
            "x",
            "x_lengths",
            "tones",
            "lang_id",
            "bert",
            "ja_bert",
            "sid",
            "noise_scale",
            "length_scale",
            "noise_scale_w",
        ],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},
            "x_lengths": {0: "N"},
            "tones": {0: "N", 1: "L"},
            "lang_id": {0: "N", 1: "L"},
            "bert": {0: "N", 2: "L"},
            "ja_bert": {0: "N", 2: "L"},
            "y": {0: "N", 1: "S", 2: "T"},
        },
    )

    meta_data = {
        "model_type": "melo-vits",
        "comment": "melo",
        "language": "Chinese + English",
        "add_blank": int(model.hps.data.add_blank),
        "n_speakers": 1,
        "sample_rate": model.hps.data.sampling_rate,
        "bert_dim": 1024,
        "ja_bert_dim": 768,
        "speaker_id": list(model.hps.data.spk2id.values())[0],
        "lang_id": language_id_map[model.language],
        "tone_start": language_tone_start_map[model.language],
        "url": "https://github.com/myshell-ai/MeloTTS",
        "license": "MIT license",
        "description": "MeloTTS is a high-quality multi-lingual text-to-speech library by MyShell.ai",
    }
    add_meta_data(filename, meta_data)


if __name__ == "__main__":
    main()
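
Not part of this commit: a minimal sketch of running the exported model.onnx with onnxruntime, using the same kind of dummy inputs the export above uses (random token IDs, zero tones, all-ones lang_id, zero BERT features). numpy, onnxruntime and soundfile are extra dependencies assumed here; real text-to-speech would first map text to token IDs via lexicon.txt and tokens.txt.

import numpy as np
import onnxruntime as ort
import soundfile as sf  # assumed here, only to write the waveform

sess = ort.InferenceSession("model.onnx")
meta = sess.get_modelmeta().custom_metadata_map  # metadata written by add_meta_data()
sample_rate = int(meta["sample_rate"])

x = np.random.randint(0, 10, size=(1, 60)).astype(np.int64)  # dummy token IDs
x_lengths = np.array([x.shape[1]], dtype=np.int64)
tones = np.zeros_like(x)
lang_id = np.ones_like(x)
bert = np.zeros((1, 1024, x.shape[1]), dtype=np.float32)
ja_bert = np.zeros((1, 768, x.shape[1]), dtype=np.float32)
sid = np.array([int(meta["speaker_id"])], dtype=np.int64)

(y,) = sess.run(
    ["y"],
    {
        "x": x,
        "x_lengths": x_lengths,
        "tones": tones,
        "lang_id": lang_id,
        "bert": bert,
        "ja_bert": ja_bert,
        "sid": sid,
        "noise_scale": np.array([1.0], dtype=np.float32),
        "length_scale": np.array([1.0], dtype=np.float32),
        "noise_scale_w": np.array([1.0], dtype=np.float32),
    },
)

# y has shape (N, 1, num_samples); see the dynamic_axes above.
sf.write("test-onnx.wav", y.squeeze(), sample_rate)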

0 commit comments