Skip to content

Commit b32b08f

Browse files
Support custom dictionary param for TTS client (#82)
* tts: add user_dictionary * update py clients * rename variable * correct argument name * update description and fix space while joining list * update common repo SHA * remove unused imports and fixes * update common SHA --------- Co-authored-by: rmittal-github <[email protected]> Co-authored-by: Rahul Mittal <[email protected]>
1 parent 2cfd632 commit b32b08f

File tree

3 files changed

+37
-3
lines changed

3 files changed

+37
-3
lines changed

common

riva/client/tts.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
from riva.client.proto.riva_audio_pb2 import AudioEncoding
1212
import wave
1313

14+
def add_custom_dictionary_to_config(req, custom_dictionary):
15+
result_list = [f"{key} {value}" for key, value in custom_dictionary.items()]
16+
result_string = ','.join(result_list)
17+
req.custom_dictionary = result_string
18+
1419
class SpeechSynthesisService:
1520
"""
1621
A class for synthesizing speech from text. Provides :meth:`synthesize` which returns entire audio for a text
@@ -38,6 +43,7 @@ def synthesize(
3843
audio_prompt_encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
3944
quality: int = 20,
4045
future: bool = False,
46+
custom_dictionary: Optional[dict] = None,
4147
) -> Union[rtts.SynthesizeSpeechResponse, _MultiThreadedRendezvous]:
4248
"""
4349
Synthesizes an entire audio for text :param:`text`.
@@ -56,6 +62,7 @@ def synthesize(
5662
audio but also takes longer to generate the audio. Ranges between 1-40.
5763
future (:obj:`bool`, defaults to :obj:`False`): Whether to return an async result instead of usual
5864
response. You can get a response by calling ``result()`` method of the future object.
65+
custom_dictionary (:obj:`dict`, `optional`): Dictionary with key-value pair containing grapheme and corresponding phoneme
5966
6067
Returns:
6168
:obj:`Union[riva.client.proto.riva_tts_pb2.SynthesizeSpeechResponse, grpc._channel._MultiThreadedRendezvous]`:
@@ -81,6 +88,8 @@ def synthesize(
8188
req.zero_shot_data.encoding = audio_prompt_encoding
8289
req.zero_shot_data.quality = quality
8390

91+
add_custom_dictionary_to_config(req, custom_dictionary)
92+
8493
func = self.stub.Synthesize.future if future else self.stub.Synthesize
8594
return func(req, metadata=self.auth.get_auth_metadata())
8695

@@ -94,6 +103,7 @@ def synthesize_online(
94103
audio_prompt_file: Optional[str] = None,
95104
audio_prompt_encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
96105
quality: int = 20,
106+
custom_dictionary: Optional[dict] = None,
97107
) -> Generator[rtts.SynthesizeSpeechResponse, None, None]:
98108
"""
99109
Synthesizes and yields output audio chunks for text :param:`text` as the chunks
@@ -111,6 +121,7 @@ def synthesize_online(
111121
audio_prompt_encoding: (:obj:`AudioEncoding`): Encoding of audio prompt file, e.g. ``AudioEncoding.LINEAR_PCM``.
112122
quality: (:obj:`int`): This defines the number of times decoder is run. Higher number improves quality of generated
113123
audio but also takes longer to generate the audio. Ranges between 1-40.
124+
custom_dictionary (:obj:`dict`, `optional`): Dictionary with key-value pair containing grapheme and corresponding phoneme
114125
115126
Yields:
116127
:obj:`riva.client.proto.riva_tts_pb2.SynthesizeSpeechResponse`: a response with output. You may find
@@ -138,4 +149,6 @@ def synthesize_online(
138149
req.zero_shot_data.encoding = audio_prompt_encoding
139150
req.zero_shot_data.quality = quality
140151

152+
add_custom_dictionary_to_config(req, custom_dictionary)
153+
141154
return self.stub.SynthesizeOnline(req, metadata=self.auth.get_auth_metadata())

scripts/tts/talk.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,20 @@
1010
import riva.client
1111
from riva.client.argparse_utils import add_connection_argparse_parameters
1212

13+
def read_file_to_dict(file_path):
14+
result_dict = {}
15+
with open(file_path, 'r') as file:
16+
for line_number, line in enumerate(file, start=1):
17+
line = line.strip()
18+
try:
19+
key, value = line.split(' ', 1) # Split by double space
20+
result_dict[str(key.strip())] = str(value.strip())
21+
except ValueError:
22+
print(f"Warning: Malformed line {line}")
23+
continue
24+
if not result_dict:
25+
raise ValueError("Error: No valid entries found in the file.")
26+
return result_dict
1327

1428
def parse_args() -> argparse.Namespace:
1529
parser = argparse.ArgumentParser(
@@ -42,6 +56,7 @@ def parse_args() -> argparse.Namespace:
4256
parser.add_argument(
4357
"--sample-rate-hz", type=int, default=44100, help="Number of audio frames per second in synthesized audio."
4458
)
59+
parser.add_argument("--custom-dictionary", type=str, help="A file path to a user dictionary with key-value pairs separated by double spaces.")
4560
parser.add_argument(
4661
"--stream",
4762
action="store_true",
@@ -108,12 +123,17 @@ def main() -> None:
108123
out_f.setsampwidth(sampwidth)
109124
out_f.setframerate(args.sample_rate_hz)
110125

126+
custom_dictionary_input = {}
127+
if args.custom_dictionary is not None:
128+
custom_dictionary_input = read_file_to_dict(args.custom_dictionary)
129+
111130
print("Generating audio for request...")
112131
start = time.time()
113132
if args.stream:
114133
responses = service.synthesize_online(
115134
args.text, args.voice, args.language_code, sample_rate_hz=args.sample_rate_hz,
116-
audio_prompt_file=args.audio_prompt_file, quality=20 if args.quality is None else args.quality
135+
audio_prompt_file=args.audio_prompt_file, quality=20 if args.quality is None else args.quality,
136+
custom_dictionary=custom_dictionary_input
117137
)
118138
first = True
119139
for resp in responses:
@@ -128,7 +148,8 @@ def main() -> None:
128148
else:
129149
resp = service.synthesize(
130150
args.text, args.voice, args.language_code, sample_rate_hz=args.sample_rate_hz,
131-
audio_prompt_file=args.audio_prompt_file, quality=20 if args.quality is None else args.quality
151+
audio_prompt_file=args.audio_prompt_file, quality=20 if args.quality is None else args.quality,
152+
custom_dictionary=custom_dictionary_input
132153
)
133154
stop = time.time()
134155
print(f"Time spent: {(stop - start):.3f}s")

0 commit comments

Comments
 (0)