Skip to content

Commit 8814150

Browse files
feat(tts): Adding support for zero shot model (#63)
* feat(tts): Adding support for pflow model input. * chore(tts): Updated SHA for common github repo
1 parent 153ebf0 commit 8814150

File tree

2 files changed

+47
-12
lines changed

2 files changed

+47
-12
lines changed

riva/client/tts.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import riva.client.proto.riva_tts_pb2_grpc as rtts_srv
1010
from riva.client import Auth
1111
from riva.client.proto.riva_audio_pb2 import AudioEncoding
12-
12+
import wave
1313

1414
class SpeechSynthesisService:
1515
"""
@@ -34,20 +34,27 @@ def synthesize(
3434
language_code: str = 'en-US',
3535
encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
3636
sample_rate_hz: int = 44100,
37+
audio_prompt_file: Optional[str] = None,
38+
audio_prompt_encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
39+
quality: int = 20,
3740
future: bool = False,
3841
) -> Union[rtts.SynthesizeSpeechResponse, _MultiThreadedRendezvous]:
3942
"""
4043
Synthesizes an entire audio for text :param:`text`.
4144
4245
Args:
43-
text (:obj:`str`): an input text.
44-
voice_name (:obj:`str`, `optional`): a name of the voice, e.g. ``"English-US-Female-1"``. You may find
46+
text (:obj:`str`): An input text.
47+
voice_name (:obj:`str`, `optional`): A name of the voice, e.g. ``"English-US-Female-1"``. You may find
4548
available voices in server logs or in server model directory. If this parameter is :obj:`None`, then
4649
a server will select the first available model with correct :param:`language_code` value.
4750
language_code (:obj:`str`): a language to use.
48-
encoding (:obj:`AudioEncoding`): an output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
49-
sample_rate_hz (:obj:`int`): number of frames per second in output audio.
50-
future (:obj:`bool`, defaults to :obj:`False`): whether to return an async result instead of usual
51+
encoding (:obj:`AudioEncoding`): An output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
52+
sample_rate_hz (:obj:`int`): Number of frames per second in output audio.
53+
audio_prompt_file (:obj:`str`): An audio prompt file location for zero shot model.
54+
audio_prompt_encoding: (:obj:`AudioEncoding`): Encoding of audio prompt file, e.g. ``AudioEncoding.LINEAR_PCM``.
55+
quality: (:obj:`int`): This defines the number of times decoder is run. Higher number improves quality of generated
56+
audio but also takes longer to generate the audio. Ranges between 1-40.
57+
future (:obj:`bool`, defaults to :obj:`False`): Whether to return an async result instead of usual
5158
response. You can get a response by calling ``result()`` method of the future object.
5259
5360
Returns:
@@ -64,6 +71,16 @@ def synthesize(
6471
)
6572
if voice_name is not None:
6673
req.voice_name = voice_name
74+
if audio_prompt_file is not None:
75+
with wave.open(str(audio_prompt_file), 'rb') as wf:
76+
rate = wf.getframerate()
77+
req.zero_shot_data.sample_rate = rate
78+
with audio_prompt_file.open('rb') as wav_f:
79+
audio_data = wav_f.read()
80+
req.zero_shot_data.audio_prompt = audio_data
81+
req.zero_shot_data.encoding = audio_prompt_encoding
82+
req.zero_shot_data.quality = quality
83+
6784
func = self.stub.Synthesize.future if future else self.stub.Synthesize
6885
return func(req, metadata=self.auth.get_auth_metadata())
6986

@@ -74,19 +91,26 @@ def synthesize_online(
7491
language_code: str = 'en-US',
7592
encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
7693
sample_rate_hz: int = 44100,
94+
audio_prompt_file: Optional[str] = None,
95+
audio_prompt_encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
96+
quality: int = 20,
7797
) -> Generator[rtts.SynthesizeSpeechResponse, None, None]:
7898
"""
7999
Synthesizes and yields output audio chunks for text :param:`text` as the chunks
80100
becoming available.
81101
82102
Args:
83-
text (:obj:`str`): an input text.
84-
voice_name (:obj:`str`, `optional`): a name of the voice, e.g. ``"English-US-Female-1"``. You may find
103+
text (:obj:`str`): An input text.
104+
voice_name (:obj:`str`, `optional`): A name of the voice, e.g. ``"English-US-Female-1"``. You may find
85105
available voices in server logs or in server model directory. If this parameter is :obj:`None`, then
86106
a server will select the first available model with correct :param:`language_code` value.
87-
language_code (:obj:`str`): a language to use.
88-
encoding (:obj:`AudioEncoding`): an output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
89-
sample_rate_hz (:obj:`int`): number of frames per second in output audio.
107+
language_code (:obj:`str`): A language to use.
108+
encoding (:obj:`AudioEncoding`): An output audio encoding, e.g. ``AudioEncoding.LINEAR_PCM``.
109+
sample_rate_hz (:obj:`int`): Number of frames per second in output audio.
110+
audio_prompt_file (:obj:`str`): An audio prompt file location for zero shot model.
111+
audio_prompt_encoding: (:obj:`AudioEncoding`): Encoding of audio prompt file, e.g. ``AudioEncoding.LINEAR_PCM``.
112+
quality: (:obj:`int`): This defines the number of times decoder is run. Higher number improves quality of generated
113+
audio but also takes longer to generate the audio. Ranges between 1-40.
90114
91115
Yields:
92116
:obj:`riva.client.proto.riva_tts_pb2.SynthesizeSpeechResponse`: a response with output. You may find
@@ -103,4 +127,15 @@ def synthesize_online(
103127
)
104128
if voice_name is not None:
105129
req.voice_name = voice_name
130+
131+
if audio_prompt_file is not None:
132+
with wave.open(str(audio_prompt_file), 'rb') as wf:
133+
rate = wf.getframerate()
134+
req.zero_shot_data.sample_rate = rate
135+
with audio_prompt_file.open('rb') as wav_f:
136+
audio_data = wav_f.read()
137+
req.zero_shot_data.audio_prompt = audio_data
138+
req.zero_shot_data.encoding = audio_prompt_encoding
139+
req.zero_shot_data.quality = quality
140+
106141
return self.stub.SynthesizeOnline(req, metadata=self.auth.get_auth_metadata())

0 commit comments

Comments
 (0)