Skip to content

Commit b11a9ad

Browse files
junkinrmittal-github
authored andcommitted
feat: s2s streaming demo app
1 parent 2d3a719 commit b11a9ad

File tree

3 files changed

+166
-2
lines changed

3 files changed

+166
-2
lines changed

riva/client/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,6 @@
3636
from riva.client.proto.riva_asr_pb2 import RecognitionConfig, StreamingRecognitionConfig
3737
from riva.client.proto.riva_audio_pb2 import AudioEncoding
3838
from riva.client.proto.riva_nlp_pb2 import AnalyzeIntentOptions
39+
from riva.client.proto.riva_nmt_pb2 import StreamingTranslateSpeechToSpeechConfig
3940
from riva.client.tts import SpeechSynthesisService
4041
from riva.client.nmt import NeuralMachineTranslationClient

riva/client/nmt.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: MIT
33

4-
from typing import Generator, Optional, Union, List
5-
4+
from typing import Callable, Dict, Generator, Iterable, List, Optional, TextIO, Union
65
from grpc._channel import _MultiThreadedRendezvous
76

87
import riva.client.proto.riva_nmt_pb2 as riva_nmt
98
import riva.client.proto.riva_nmt_pb2_grpc as riva_nmt_srv
109
from riva.client import Auth
1110

11+
def streaming_s2s_request_generator(
12+
audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToSpeechConfig
13+
) -> Generator[riva_nmt.StreamingTranslateSpeechToSpeechRequest, None, None]:
14+
yield riva_nmt.StreamingTranslateSpeechToSpeechRequest(config=streaming_config)
15+
for chunk in audio_chunks:
16+
yield riva_nmt.StreamingTranslateSpeechToSpeechRequest(audio_content=chunk)
17+
1218

1319
class NeuralMachineTranslationClient:
1420
"""
@@ -25,6 +31,49 @@ def __init__(self, auth: Auth) -> None:
2531
self.auth = auth
2632
self.stub = riva_nmt_srv.RivaTranslationStub(self.auth.channel)
2733

34+
def streaming_s2s_response_generator(
35+
self, audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToSpeechConfig
36+
) -> Generator[riva_nmt.StreamingTranslateSpeechToSpeechResponse, None, None]:
37+
"""
38+
Generates speech recognition responses for fragments of speech audio in :param:`audio_chunks`.
39+
The purpose of the method is to perform speech recognition "online" - as soon as
40+
audio is acquired on small chunks of audio.
41+
42+
All available audio chunks will be sent to a server on first ``next()`` call.
43+
44+
Args:
45+
audio_chunks (:obj:`Iterable[bytes]`): an iterable object which contains raw audio fragments
46+
of speech. For example, such raw audio can be obtained with
47+
48+
.. code-block:: python
49+
50+
import wave
51+
with wave.open(file_name, 'rb') as wav_f:
52+
raw_audio = wav_f.readframes(n_frames)
53+
54+
streaming_config (:obj:`riva.client.proto.riva_asr_pb2.StreamingRecognitionConfig`): a config for streaming.
55+
You may find description of config fields in message ``StreamingRecognitionConfig`` in
56+
`common repo
57+
<https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-asr-proto>`_.
58+
An example of creation of streaming config:
59+
60+
.. code-style:: python
61+
62+
from riva.client import RecognitionConfig, StreamingRecognitionConfig
63+
config = RecognitionConfig(enable_automatic_punctuation=True)
64+
streaming_config = StreamingRecognitionConfig(config, interim_results=True)
65+
66+
Yields:
67+
:obj:`riva.client.proto.riva_asr_pb2.StreamingRecognizeResponse`: responses for audio chunks in
68+
:param:`audio_chunks`. You may find description of response fields in declaration of
69+
``StreamingRecognizeResponse``
70+
message `here
71+
<https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-asr-proto>`_.
72+
"""
73+
generator = streaming_s2s_request_generator(audio_chunks, streaming_config)
74+
for response in self.stub.StreamingTranslateSpeechToSpeech(generator, metadata=self.auth.get_auth_metadata()):
75+
yield response
76+
2877
def translate(
2978
self,
3079
texts: List[str],

scripts/nmt/s2s_mic.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: MIT
3+
4+
import argparse
5+
import wave
6+
import riva.client
7+
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters
8+
from typing import Callable, Dict, Generator, Iterable, List, Optional, TextIO, Union
9+
import riva.client.audio_io
10+
import riva.client.proto.riva_nmt_pb2 as riva_nmt
11+
12+
def parse_args() -> argparse.Namespace:
13+
default_device_info = riva.client.audio_io.get_default_input_device_info()
14+
default_device_index = None if default_device_info is None else default_device_info['index']
15+
parser = argparse.ArgumentParser(
16+
description="Streaming speech to speech translation from microphone via Riva AI Services",
17+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
18+
)
19+
parser.add_argument("--input-device", type=int, default=default_device_index, help="An input audio device to use.")
20+
parser.add_argument("--list-input-devices", action="store_true", help="List input audio device indices.")
21+
parser.add_argument("--list-output-devices", action="store_true", help="List input audio device indices.")
22+
parser.add_argument("--output-device", type=int, help="Output device to use.")
23+
parser.add_argument(
24+
"--play-audio",
25+
action="store_true",
26+
help="Play input audio simultaneously with transcribing and translating it. If `--output-device` is not provided, "
27+
"then the default output audio device will be used.",
28+
)
29+
30+
parser = add_asr_config_argparse_parameters(parser, profanity_filter=True)
31+
parser = add_connection_argparse_parameters(parser)
32+
parser.add_argument(
33+
"--sample-rate-hz",
34+
type=int,
35+
help="A number of frames per second in audio streamed from a microphone.",
36+
default=16000,
37+
)
38+
parser.add_argument(
39+
"--file-streaming-chunk",
40+
type=int,
41+
default=1600,
42+
help="A maximum number of frames in a audio chunk sent to server.",
43+
)
44+
args = parser.parse_args()
45+
return args
46+
47+
def play_responses(responses: Iterable[riva_nmt.StreamingTranslateSpeechToSpeechResponse],
48+
sound_stream) -> None:
49+
count = 0
50+
for response in responses:
51+
#if first:
52+
#print(f"time to first audio {(stop - start):.3f}s")
53+
# first=False
54+
if sound_stream is not None:
55+
sound_stream(response.speech.audio)
56+
fname = "response" + str(count)
57+
out_f = wave.open(fname, 'wb')
58+
out_f.setnchannels(1)
59+
out_f.setsampwidth(2)
60+
out_f.setframerate(44100)
61+
count += 1
62+
63+
64+
def main() -> None:
65+
args = parse_args()
66+
sound_stream = None
67+
sampwidth = 2
68+
nchannels = 1
69+
if args.list_input_devices:
70+
riva.client.audio_io.list_input_devices()
71+
return
72+
if args.output_device is not None or args.play_audio:
73+
print("playing audio")
74+
sound_stream = riva.client.audio_io.SoundCallBack(
75+
args.output_device, nchannels=nchannels, sampwidth=sampwidth, framerate=44100
76+
)
77+
print(sound_stream)
78+
first = True # first tts output chunk received
79+
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server)
80+
nmt_service = riva.client.NeuralMachineTranslationClient(auth)
81+
s2s_config = riva.client.StreamingTranslateSpeechToSpeechConfig(
82+
asrConfig = riva.client.StreamingRecognitionConfig(
83+
config=riva.client.RecognitionConfig(
84+
encoding=riva.client.AudioEncoding.LINEAR_PCM,
85+
language_code=args.language_code,
86+
max_alternatives=1,
87+
profanity_filter=args.profanity_filter,
88+
enable_automatic_punctuation=args.automatic_punctuation,
89+
verbatim_transcripts=not args.no_verbatim_transcripts,
90+
sample_rate_hertz=args.sample_rate_hz,
91+
audio_channel_count=1,
92+
),
93+
interim_results=True,
94+
)
95+
)
96+
97+
#riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
98+
with riva.client.audio_io.MicrophoneStream(
99+
args.sample_rate_hz,
100+
args.file_streaming_chunk,
101+
device=args.input_device,
102+
) as audio_chunk_iterator:
103+
play_responses(responses=nmt_service.streaming_s2s_response_generator(
104+
audio_chunks=audio_chunk_iterator,
105+
streaming_config=s2s_config), sound_stream=sound_stream)
106+
# if first:
107+
# first = False
108+
# if sound_stream is not None:
109+
# sound_stream(response.audio)
110+
111+
112+
113+
if __name__ == '__main__':
114+
main()

0 commit comments

Comments
 (0)