
Commit b7223d4

Browse files
junkinrmittal-github
authored and committed
feat: s2s streaming demo app
1 parent 22438ab commit b7223d4

File tree

1 file changed (+114, -0)


scripts/nmt/s2s_mic.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT

import argparse
import wave
from typing import Iterable

import riva.client
import riva.client.audio_io
import riva.client.proto.riva_nmt_pb2 as riva_nmt
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters


def parse_args() -> argparse.Namespace:
    default_device_info = riva.client.audio_io.get_default_input_device_info()
    default_device_index = None if default_device_info is None else default_device_info['index']
    parser = argparse.ArgumentParser(
        description="Streaming speech-to-speech translation from microphone via Riva AI Services",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--input-device", type=int, default=default_device_index, help="An input audio device to use.")
    parser.add_argument("--list-input-devices", action="store_true", help="List input audio device indices.")
    parser.add_argument("--list-output-devices", action="store_true", help="List output audio device indices.")
    parser.add_argument("--output-device", type=int, help="Output device to use.")
    parser.add_argument(
        "--play-audio",
        action="store_true",
        help="Play translated speech as it is received from the server. If `--output-device` is not provided, "
        "then the default output audio device will be used.",
    )

    parser = add_asr_config_argparse_parameters(parser, profanity_filter=True)
    parser = add_connection_argparse_parameters(parser)
    parser.add_argument(
        "--sample-rate-hz",
        type=int,
        help="Number of frames per second in audio streamed from the microphone.",
        default=16000,
    )
    parser.add_argument(
        "--file-streaming-chunk",
        type=int,
        default=1600,
        help="Maximum number of frames in an audio chunk sent to the server.",
    )
    args = parser.parse_args()
    return args


def play_responses(
    responses: Iterable[riva_nmt.StreamingTranslateSpeechToSpeechResponse], sound_stream
) -> None:
    count = 0
    for response in responses:
        # Play the synthesized translation as soon as a chunk arrives.
        if sound_stream is not None:
            sound_stream(response.speech.audio)
        # Also save each response chunk to its own WAV file.
        fname = "response" + str(count) + ".wav"
        with wave.open(fname, 'wb') as out_f:
            out_f.setnchannels(1)
            out_f.setsampwidth(2)
            out_f.setframerate(44100)
            out_f.writeframes(response.speech.audio)
        count += 1


def main() -> None:
    args = parse_args()
    sound_stream = None
    sampwidth = 2
    nchannels = 1
    if args.list_input_devices:
        riva.client.audio_io.list_input_devices()
        return
    if args.list_output_devices:
        riva.client.audio_io.list_output_devices()
        return
    if args.output_device is not None or args.play_audio:
        print("playing audio")
        sound_stream = riva.client.audio_io.SoundCallBack(
            args.output_device, nchannels=nchannels, sampwidth=sampwidth, framerate=44100
        )
        print(sound_stream)
    auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
    nmt_service = riva.client.NeuralMachineTranslationClient(auth)
    s2s_config = riva.client.StreamingTranslateSpeechToSpeechConfig(
        asrConfig=riva.client.StreamingRecognitionConfig(
            config=riva.client.RecognitionConfig(
                encoding=riva.client.AudioEncoding.LINEAR_PCM,
                language_code=args.language_code,
                max_alternatives=1,
                profanity_filter=args.profanity_filter,
                enable_automatic_punctuation=args.automatic_punctuation,
                verbatim_transcripts=not args.no_verbatim_transcripts,
                sample_rate_hertz=args.sample_rate_hz,
                audio_channel_count=1,
            ),
            interim_results=True,
        )
    )

    # Stream microphone audio to the server and play/save the translated speech responses.
    with riva.client.audio_io.MicrophoneStream(
        args.sample_rate_hz,
        args.file_streaming_chunk,
        device=args.input_device,
    ) as audio_chunk_iterator:
        play_responses(
            responses=nmt_service.streaming_s2s_response_generator(
                audio_chunks=audio_chunk_iterator,
                streaming_config=s2s_config,
            ),
            sound_stream=sound_stream,
        )


if __name__ == '__main__':
    main()
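
As a usage note, the demo would typically be launched with something like `python scripts/nmt/s2s_mic.py --server localhost:50051 --play-audio`, where `--server` comes from add_connection_argparse_parameters and the address is a placeholder. Below is a minimal sketch of driving the same streaming speech-to-speech config from a pre-recorded WAV file instead of a live microphone. It assumes riva.client.AudioChunkFileIterator is available as in the ASR example scripts; the file name, server address, and chunk size are illustrative placeholders and are not part of this commit.

# Sketch: speech-to-speech translation from a WAV file instead of a microphone.
# Assumes riva.client.AudioChunkFileIterator (used by the ASR example scripts)
# is available; "input.wav" and the server address are placeholders.
import riva.client

auth = riva.client.Auth(None, False, "localhost:50051")  # ssl_cert, use_ssl, server
nmt_service = riva.client.NeuralMachineTranslationClient(auth)

s2s_config = riva.client.StreamingTranslateSpeechToSpeechConfig(
    asrConfig=riva.client.StreamingRecognitionConfig(
        config=riva.client.RecognitionConfig(
            encoding=riva.client.AudioEncoding.LINEAR_PCM,
            language_code="en-US",
            sample_rate_hertz=16000,
            audio_channel_count=1,
        ),
        interim_results=True,
    )
)

# Stream 1600-frame chunks from the file and report the size of each chunk of
# synthesized translated speech returned by the server.
audio_chunks = riva.client.AudioChunkFileIterator("input.wav", 1600)
for response in nmt_service.streaming_s2s_response_generator(
    audio_chunks=audio_chunks, streaming_config=s2s_config
):
    print(f"received {len(response.speech.audio)} bytes of translated audio")

Playback or saving of response.speech.audio would follow the same pattern as play_responses above.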
