Skip to content

Commit 6b2ad08

Browse files
authored
Merge pull request #4 from PyThaiNLP/add-tts-thai
Add lunarlist_model
2 parents d8839c1 + a70fdc0 commit 6b2ad08

File tree

6 files changed

+477
-6
lines changed

6 files changed

+477
-6
lines changed

notebook/use_lunarlist_model.ipynb

Lines changed: 414 additions & 0 deletions
Large diffs are not rendered by default.

pythaitts/__init__.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
"""
33
PyThaiTTS
44
"""
5-
__version__ = "0.1.1"
5+
__version__ = "0.2.0"
66

77

88
class TTS:
9-
def __init__(self, pretrained="khanomtan", mode="last_checkpoint", version="1.0") -> None:
9+
def __init__(self, pretrained="khanomtan", mode="last_checkpoint", version="1.0", device:str="cpu") -> None:
1010
"""
11-
:param str pretrained: TTS pretrained (khanomtan)
11+
:param str pretrained: TTS pretrained (khanomtan, lunarlist)
1212
:param str mode: pretrained mode
1313
:param str version: model version (default is 1.0 or 1.1)
1414
@@ -18,9 +18,14 @@ def __init__(self, pretrained="khanomtan", mode="last_checkpoint", version="1.0"
1818
1919
You can see more about khanomtan tts at `https://github.com/wannaphong/KhanomTan-TTS-v1.0 <https://github.com/wannaphong/KhanomTan-TTS-v1.0>`_
2020
and `https://github.com/wannaphong/KhanomTan-TTS-v1.1 <https://github.com/wannaphong/KhanomTan-TTS-v1.1>`_
21+
22+
For lunarlist tts model, you must to install nemo before use the model by pip install nemo_toolkit['tts'].
23+
You can see more about lunarlist tts at `https://link.medium.com/OpPjQis6wBb <https://link.medium.com/OpPjQis6wBb>`_
24+
2125
"""
2226
self.pretrained = pretrained
2327
self.mode = mode
28+
self.device = device
2429
self.load_pretrained(version=version)
2530

2631
def load_pretrained(self,version):
@@ -30,6 +35,9 @@ def load_pretrained(self,version):
3035
if self.pretrained == "khanomtan":
3136
from pythaitts.pretrained import KhanomTan
3237
self.model = KhanomTan(mode=self.mode, version=version)
38+
elif self.pretrained == "lunarlist":
39+
from pythaitts.pretrained import LunarlistModel
40+
self.model = LunarlistModel(mode=self.mode, device=self.device)
3341
else:
3442
raise NotImplemented(
3543
"PyThaiTTS doesn't support %s pretrained." % self.pretrained
@@ -45,6 +53,8 @@ def tts(self, text: str, speaker_idx: str = "Linda", language_idx: str = "th-th"
4553
:param str return_type: return type (default is file)
4654
:param str filename: path filename for save wav file if return_type is file.
4755
"""
56+
if self.pretrained == "lunarlist":
57+
return self.model(text=text,return_type=return_type,filename=filename)
4858
return self.model(
4959
text=text,
5060
speaker_idx=speaker_idx,

pythaitts/pretrained/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# -*- coding: utf-8 -*-
22
from pythaitts.pretrained.khanomtan_tts import KhanomTan
3+
from pythaitts.pretrained.lunarlist_model import LunarlistModel
34

45
__all__ = [
5-
"KhanomTan"
6+
"KhanomTan",
7+
"LunarlistModel"
68
]
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
Lunarlist TTS model
4+
5+
You can see more about lunarlist tts at `https://link.medium.com/OpPjQis6wBb <https://link.medium.com/OpPjQis6wBb>`_
6+
"""
7+
import tempfile
8+
import torch
9+
10+
11+
class LunarlistModel:
12+
def __init__(self, mode:str="last_checkpoint", device:str="cpu") -> None:
13+
try:
14+
from nemo.collections.tts.models import UnivNetModel
15+
except ImportError:
16+
raise ImportError("You must to install nemo by pip install nemo_toolkit['tts'] before use this model.")
17+
self.mode = mode
18+
self.device = device
19+
self.vcoder_model = UnivNetModel.from_pretrained(model_name="tts_en_libritts_univnet").to(self.device)
20+
self.load_synthesizer(self.mode)
21+
def load_synthesizer(self, mode:str):
22+
from nemo.collections.tts.models import Tacotron2Model
23+
if mode=="last_checkpoint":
24+
self.model = Tacotron2Model.from_pretrained("lunarlist/tts-thai-last-step").to(self.device)
25+
else:
26+
self.model = Tacotron2Model.from_pretrained("lunarlist/tts-thai").to(self.device)
27+
self.dict_idx={k:i for i,k in enumerate(self.model.hparams["cfg"]['labels'])}
28+
def tts(self, text: str):
29+
parsed2=torch.Tensor([[66]+[self.dict_idx[i] for i in text if i]+[67]]).int().to(self.device)
30+
spectrogram2 = self.model.generate_spectrogram(tokens=parsed2)
31+
audio2 = self.vcoder_model.convert_spectrogram_to_audio(spec=spectrogram2)
32+
return audio2.to('cpu').detach().numpy()
33+
def __call__(self, text: str,return_type: str = "file", filename: str = None):
34+
wavs = self.tts(text)
35+
if return_type == "waveform":
36+
return wavs
37+
import soundfile as sf
38+
if filename != None:
39+
sf.write(filename, wavs[0], 22050)
40+
return filename
41+
else:
42+
with tempfile.NamedTemporaryFile(suffix = ".wav", delete = False) as fp:
43+
fp.write(wavs[0])
44+
return fp.name

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
TTS>=0.8.0
22
pythainlp>=3.0.0
3-
huggingface_hub
3+
huggingface_hub
4+
torch

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
setup(
1111
name="PyThaiTTS",
12-
version="0.1.1",
12+
version="0.2.0",
1313
description="Open Source Thai Text-to-speech library in Python",
1414
long_description=readme,
1515
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)