22"""
33PyThaiTTS
44"""
5- __version__ = "0.1.1 "
5+ __version__ = "0.2.0 "
66
77
88class TTS :
9- def __init__ (self , pretrained = "khanomtan" , mode = "last_checkpoint" , version = "1.0" ) -> None :
9+ def __init__ (self , pretrained = "khanomtan" , mode = "last_checkpoint" , version = "1.0" , device : str = "cpu" ) -> None :
1010 """
11- :param str pretrained: TTS pretrained (khanomtan)
11+ :param str pretrained: TTS pretrained (khanomtan, lunarlist )
1212 :param str mode: pretrained mode
1313 :param str version: model version (default is 1.0 or 1.1)
1414
@@ -18,9 +18,14 @@ def __init__(self, pretrained="khanomtan", mode="last_checkpoint", version="1.0"
1818
1919 You can see more about khanomtan tts at `https://github.com/wannaphong/KhanomTan-TTS-v1.0 <https://github.com/wannaphong/KhanomTan-TTS-v1.0>`_
2020 and `https://github.com/wannaphong/KhanomTan-TTS-v1.1 <https://github.com/wannaphong/KhanomTan-TTS-v1.1>`_
21+
22+ For lunarlist tts model, you must to install nemo before use the model by pip install nemo_toolkit['tts'].
23+ You can see more about lunarlist tts at `https://link.medium.com/OpPjQis6wBb <https://link.medium.com/OpPjQis6wBb>`_
24+
2125 """
2226 self .pretrained = pretrained
2327 self .mode = mode
28+ self .device = device
2429 self .load_pretrained (version = version )
2530
2631 def load_pretrained (self ,version ):
@@ -30,6 +35,9 @@ def load_pretrained(self,version):
3035 if self .pretrained == "khanomtan" :
3136 from pythaitts .pretrained import KhanomTan
3237 self .model = KhanomTan (mode = self .mode , version = version )
38+ elif self .pretrained == "lunarlist" :
39+ from pythaitts .pretrained import LunarlistModel
40+ self .model = LunarlistModel (mode = self .mode , device = self .device )
3341 else :
3442 raise NotImplemented (
3543 "PyThaiTTS doesn't support %s pretrained." % self .pretrained
@@ -45,6 +53,8 @@ def tts(self, text: str, speaker_idx: str = "Linda", language_idx: str = "th-th"
4553 :param str return_type: return type (default is file)
4654 :param str filename: path filename for save wav file if return_type is file.
4755 """
56+ if self .pretrained == "lunarlist" :
57+ return self .model (text = text ,return_type = return_type ,filename = filename )
4858 return self .model (
4959 text = text ,
5060 speaker_idx = speaker_idx ,
0 commit comments