Skip to content

Commit 9c33ebf

Browse files
authored
Merge pull request #211 from TensorSpeech/add_eos_abstract_method
📝 Add abstractmethod 'setup_eos_token' to base_processor.
2 parents b38187c + 7cb539b commit 9c33ebf

File tree

7 files changed

+28
-2
lines changed

7 files changed

+28
-2
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
# TODO(@dathudeptrai) update requirement if needed.
2323
requirements = {
2424
"install": [
25-
"tensorflow-gpu==2.3.0",
25+
"tensorflow-gpu>=2.2.0",
2626
"tensorflow-addons>=0.10.0",
2727
"setuptools>=38.5.1",
2828
"librosa>=0.7.0",

tensorflow_tts/processor/baker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,9 @@ def __post_init__(self):
549549
super().__post_init__()
550550
self.pinyin_parser = self.get_pinyin_parser()
551551

552+
def setup_eos_token(self):
553+
return _eos[0]
554+
552555
def create_items(self):
553556
items = []
554557
if self.data_dir:

tensorflow_tts/processor/base_processor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,16 @@ def __post_init__(self):
6161
self.create_symbols()
6262
if self.saved_mapper_path is not None:
6363
self._save_mapper(saved_path=self.saved_mapper_path)
64-
64+
6565
# processor name. usefull to use it for AutoProcessor
6666
self._processor_name = type(self).__name__
6767

68+
if self.setup_eos_token():
69+
self.add_symbol(
70+
self.setup_eos_token()
71+
) # if this eos token not yet present in symbols list.
72+
self.eos_id = self.symbol_to_id[self.setup_eos_token()]
73+
6874
def __getattr__(self, name: str) -> Union[str, int]:
6975
if "_id" in name: # map symbol to id
7076
return self.symbol_to_id[name.replace("_id", "")]
@@ -151,6 +157,11 @@ def get_one_sample(self, item):
151157
def text_to_sequence(self, text: str):
152158
return []
153159

160+
@abc.abstractmethod
161+
def setup_eos_token(self):
162+
"""Return eos symbol of type string."""
163+
return "eos"
164+
154165
def convert_symbols_to_ids(self, symbols: Union[str, list]):
155166
sequence = []
156167
if isinstance(symbols, str):

tensorflow_tts/processor/kss.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def split_line(self, data_dir, line, split):
5454
speaker_name = "kss"
5555
return text_norm, wav_path, speaker_name
5656

57+
def setup_eos_token(self):
58+
return "eos"
59+
5760
def get_one_sample(self, item):
5861
text, wav_path, speaker_name = item
5962

tensorflow_tts/processor/libritts.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ def get_one_sample(self, item):
8383

8484
return sample
8585

86+
def setup_eos_token(self):
87+
return None # because we do not use this
88+
8689
def text_to_sequence(self, text):
8790
if (
8891
self.mode == "train"

tensorflow_tts/processor/ljspeech.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ def split_line(self, data_dir, line, split):
155155
speaker_name = "ljspeech"
156156
return text_norm, wav_path, speaker_name
157157

158+
def setup_eos_token(self):
159+
return _eos
160+
158161
def get_one_sample(self, item):
159162
text, wav_path, speaker_name = item
160163

test/test_base_processor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ def get_one_sample(self, item):
2121
def text_to_sequence(self, text):
2222
return ["0"]
2323

24+
def setup_eos_token(self):
25+
return None
26+
2427

2528
@pytest.fixture
2629
def processor(tmpdir):

0 commit comments

Comments
 (0)