Skip to content

Commit 3a03a45

Browse files
committed
📑 Add AutoProcessor Class.
1 parent aa871af commit 3a03a45

File tree

5 files changed

+75
-3
lines changed

5 files changed

+75
-3
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from tensorflow_tts.inference.auto_model import TFAutoModel
22
from tensorflow_tts.inference.auto_config import AutoConfig
3+
from tensorflow_tts.inference.auto_processor import AutoProcessor

tensorflow_tts/inference/auto_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def from_pretrained(cls, pretrained_path, **kwargs):
5656
return config_class
5757
except Exception:
5858
raise ValueError(
59-
"Unrecognized model in {}. "
60-
"Should have a `model_type` key in its config.json, or contain one of the following strings "
59+
"Unrecognized config in {}. "
60+
"Should have a `model_type` key in its config.yaml, or contain one of the following strings "
6161
"in its name: {}".format(
6262
pretrained_path, ", ".join(CONFIG_MAPPING.keys())
6363
)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# -*- coding: utf-8 -*-
2+
# Copyright 2020 The TensorFlowTTS Team.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Tensorflow Auto Processor modules."""
16+
17+
import logging
18+
import json
19+
from collections import OrderedDict
20+
21+
from tensorflow_tts.processor import (
22+
LJSpeechProcessor,
23+
KSSProcessor,
24+
BakerProcessor,
25+
LibriTTSProcessor,
26+
)
27+
28+
CONFIG_MAPPING = OrderedDict(
29+
[
30+
("LJSpeechProcessor", LJSpeechProcessor),
31+
("KSSProcessor", KSSProcessor),
32+
("BakerProcessor", BakerProcessor),
33+
("LibriTTSProcessor", LibriTTSProcessor),
34+
]
35+
)
36+
37+
38+
class AutoProcessor:
39+
def __init__(self):
40+
raise EnvironmentError(
41+
"AutoProcessor is designed to be instantiated "
42+
"using the `AutoProcessor.from_pretrained(pretrained_path)` method."
43+
)
44+
45+
@classmethod
46+
def from_pretrained(cls, pretrained_path, **kwargs):
47+
with open(pretrained_path, "r") as f:
48+
config = json.load(f)
49+
50+
try:
51+
processor_name = config["processor_name"]
52+
processor_class = CONFIG_MAPPING[processor_name]
53+
processor_class = processor_class(
54+
data_dir=None, loaded_mapper_path=pretrained_path
55+
)
56+
return processor_class
57+
except Exception:
58+
raise ValueError(
59+
"Unrecognized processor in {}. "
60+
"Should have a `processor_name` key in its config.json, or contain one of the following strings "
61+
"in its name: {}".format(
62+
pretrained_path, ", ".join(CONFIG_MAPPING.keys())
63+
)
64+
)

test/files/mapper.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@
1313
"1": "b",
1414
"2": "@ph"
1515
},
16-
"processor_name": "TestProcessor"
16+
"processor_name": "KSSProcessor"
1717
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import tensorflow as tf
2121

2222
from tensorflow_tts.inference import AutoConfig
23+
from tensorflow_tts.inference import AutoProcessor
2324
from tensorflow_tts.inference import TFAutoModel
2425

2526
os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -49,3 +50,9 @@
4950
def test_auto_model(config_path):
5051
config = AutoConfig.from_pretrained(pretrained_path=config_path)
5152
model = TFAutoModel.from_pretrained(config=config, pretrained_path=None)
53+
54+
55+
@pytest.fixture
56+
def test_auto_processor():
57+
processor = AutoProcessor(data_dir=None, loaded_mapper_path="./files/mapper.json")
58+
return processor

0 commit comments

Comments
 (0)