Skip to content

Commit ebd5209

Browse files
committed
🐸 Fix pytest, add 'extra_attrs_to_save' to _save_mapper, load other keys in _load_mapper.
1 parent 2182343 commit ebd5209

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

tensorflow_tts/configs/fastspeech.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717
import collections
1818

19-
from tensorflow_tts.processor.ljspeech import symbols as lj_symbols
20-
from tensorflow_tts.processor.kss import symbols as kss_symbols
19+
from tensorflow_tts.processor.ljspeech import LJSPEECH_SYMBOLS as lj_symbols
20+
from tensorflow_tts.processor.kss import KSS_SYMBOLS as kss_symbols
2121
from tensorflow_tts.processor.baker import symbols as bk_symbols
2222

2323

tensorflow_tts/configs/tacotron2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
# limitations under the License.
1515
"""Tacotron-2 Config object."""
1616

17-
from tensorflow_tts.processor.ljspeech import symbols as lj_symbols
18-
from tensorflow_tts.processor.kss import symbols as kss_symbols
17+
from tensorflow_tts.processor.ljspeech import LJSPEECH_SYMBOLS as lj_symbols
18+
from tensorflow_tts.processor.kss import KSS_SYMBOLS as kss_symbols
1919
from tensorflow_tts.processor.baker import symbols as bk_symbols
2020

2121

tensorflow_tts/processor/base_processor.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,13 @@ def _load_mapper(self, loaded_path: str = None):
178178
self.symbol_to_id = data["symbol_to_id"]
179179
self.id_to_symbol = {int(k): v for k, v in data["id_to_symbol"].items()}
180180

181-
def _save_mapper(self, saved_path: str = None):
181+
# other keys
182+
all_data_keys = data.keys()
183+
for key in all_data_keys:
184+
if key not in ["speakers_map", "symbol_to_id", "id_to_symbol"]:
185+
setattr(self, key, data[key])
186+
187+
def _save_mapper(self, saved_path: str = None, extra_attrs_to_save: dict = None):
182188
"""
183189
Save all needed mappers to file
184190
"""
@@ -193,4 +199,6 @@ def _save_mapper(self, saved_path: str = None):
193199
"id_to_symbol": self.id_to_symbol,
194200
"speakers_map": self.speakers_map,
195201
}
202+
if extra_attrs_to_save:
203+
full_mapper = {**full_mapper, **extra_attrs_to_save}
196204
json.dump(full_mapper, f)

0 commit comments

Comments
 (0)