Skip to content

Commit 87f67d2

Browse files
committed
📝 Add abstractmethod 'setup_eos_token' to base_processor.
1 parent 63d4220 commit 87f67d2

File tree

6 files changed

+20
-1
lines changed

6 files changed

+20
-1
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
# TODO(@dathudeptrai) update requirement if needed.
2323
requirements = {
2424
"install": [
25-
"tensorflow-gpu==2.3.0",
25+
"tensorflow-gpu>=2.2.0",
2626
"tensorflow-addons>=0.10.0",
2727
"setuptools>=38.5.1",
2828
"librosa>=0.7.0",

tensorflow_tts/processor/baker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,9 @@ def __post_init__(self):
549549
super().__post_init__()
550550
self.pinyin_parser = self.get_pinyin_parser()
551551

552+
def setup_eos_token(self):
553+
return _eos[0]
554+
552555
def create_items(self):
553556
items = []
554557
if self.data_dir:

tensorflow_tts/processor/base_processor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __post_init__(self):
6464

6565
# processor name. usefull to use it for AutoProcessor
6666
self._processor_name = type(self).__name__
67+
self.eos_id = self.symbol_to_id[self.setup_eos_token()]
6768

6869
def __getattr__(self, name: str) -> Union[str, int]:
6970
if "_id" in name: # map symbol to id
@@ -151,6 +152,12 @@ def get_one_sample(self, item):
151152
def text_to_sequence(self, text: str):
152153
return []
153154

155+
@abc.abstractmethod
156+
def setup_eos_token(self):
157+
"""Return eos symbol of type string."""
158+
return "eos"
159+
160+
154161
def convert_symbols_to_ids(self, symbols: Union[str, list]):
155162
sequence = []
156163
if isinstance(symbols, str):

tensorflow_tts/processor/kss.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def split_line(self, data_dir, line, split):
5454
speaker_name = "kss"
5555
return text_norm, wav_path, speaker_name
5656

57+
def setup_eos_token(self):
58+
return "eos"
59+
5760
def get_one_sample(self, item):
5861
text, wav_path, speaker_name = item
5962

tensorflow_tts/processor/libritts.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def get_one_sample(self, item):
8383

8484
return sample
8585

86+
def setup_eos_token(self):
87+
return None
88+
8689
def text_to_sequence(self, text):
8790
if (
8891
self.mode == "train"

tensorflow_tts/processor/ljspeech.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ def split_line(self, data_dir, line, split):
155155
speaker_name = "ljspeech"
156156
return text_norm, wav_path, speaker_name
157157

158+
def setup_eos_token(self):
159+
return _eos
160+
158161
def get_one_sample(self, item):
159162
text, wav_path, speaker_name = item
160163

0 commit comments

Comments
 (0)