Skip to content

Commit cf276ba

Browse files
fix(tts): Added zero shot parameters to talk.py (#69)
1 parent e1145b8 commit cf276ba

File tree

2 files changed

+11 -4 lines changed

2 files changed

+11 -4 lines changed

riva/client/tts.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ def synthesize(
7474
if audio_prompt_file is not None:
7575
with wave.open(str(audio_prompt_file), 'rb') as wf:
7676
rate = wf.getframerate()
77-
req.zero_shot_data.sample_rate = rate
77+
req.zero_shot_data.sample_rate_hz = rate
7878
with audio_prompt_file.open('rb') as wav_f:
7979
audio_data = wav_f.read()
8080
req.zero_shot_data.audio_prompt = audio_data
@@ -131,7 +131,7 @@ def synthesize_online(
131131
if audio_prompt_file is not None:
132132
with wave.open(str(audio_prompt_file), 'rb') as wf:
133133
rate = wf.getframerate()
134-
req.zero_shot_data.sample_rate = rate
134+
req.zero_shot_data.sample_rate_hz = rate
135135
with audio_prompt_file.open('rb') as wav_f:
136136
audio_data = wav_f.read()
137137
req.zero_shot_data.audio_prompt = audio_data

scripts/tts/talk.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,12 @@ def parse_args() -> argparse.Namespace:
2222
"based on parameter `--language-code`.",
2323
)
2424
parser.add_argument("--text", type=str, required=True, help="Text input to synthesize.")
25+
parser.add_argument(
26+
"--audio_prompt_file",
27+
type=Path,
28+
help="An input audio prompt (.wav) file for zero shot model. This is required to do zero shot inferencing.")
2529
parser.add_argument("-o", "--output", type=Path, help="Output file .wav file to write synthesized audio.")
30+
parser.add_argument("--quality", type=int, help="Number of times decoder should be run on the output audio. A higher number improves quality of the produced output but introduces latencies.")
2631
parser.add_argument(
2732
"--play-audio",
2833
action="store_true",
@@ -81,7 +86,8 @@ def main() -> None:
8186
start = time.time()
8287
if args.stream:
8388
responses = service.synthesize_online(
84-
args.text, args.voice, args.language_code, sample_rate_hz=args.sample_rate_hz
89+
args.text, args.voice, args.language_code, sample_rate_hz=args.sample_rate_hz,
90+
audio_prompt_file=args.audio_prompt_file, quality=20 if args.quality is None else args.quality
8591
)
8692
first = True
8793
for resp in responses:
@@ -95,7 +101,8 @@ def main() -> None:
95101
out_f.writeframesraw(resp.audio)
96102
else:
97103
resp = service.synthesize(
98-
args.text, args.voice, args.language_code, sample_rate_hz=args.sample_rate_hz
104+
args.text, args.voice, args.language_code, sample_rate_hz=args.sample_rate_hz,
105+
audio_prompt_file=args.audio_prompt_file, quality=20 if args.quality is None else args.quality
99106
)
100107
stop = time.time()
101108
print(f"Time spent: {(stop - start):.3f}s")

0 commit comments

Comments
 (0)