Skip to content

Commit 22438ab

Browse files
add s2s and s2t client utility functions (#43)
1 parent 2d3a719 commit 22438ab

File tree

2 files changed

+108
-2
lines changed

2 files changed

+108
-2
lines changed

riva/client/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,5 +36,6 @@
3636
from riva.client.proto.riva_asr_pb2 import RecognitionConfig, StreamingRecognitionConfig
3737
from riva.client.proto.riva_audio_pb2 import AudioEncoding
3838
from riva.client.proto.riva_nlp_pb2 import AnalyzeIntentOptions
39+
from riva.client.proto.riva_nmt_pb2 import StreamingTranslateSpeechToSpeechConfig, TranslationConfig, SynthesizeSpeechConfig, StreamingTranslateSpeechToTextConfig
3940
from riva.client.tts import SpeechSynthesisService
4041
from riva.client.nmt import NeuralMachineTranslationClient

riva/client/nmt.py

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,26 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: MIT
33

4-
from typing import Generator, Optional, Union, List
5-
4+
from typing import Callable, Dict, Generator, Iterable, List, Optional, TextIO, Union
65
from grpc._channel import _MultiThreadedRendezvous
76

87
import riva.client.proto.riva_nmt_pb2 as riva_nmt
98
import riva.client.proto.riva_nmt_pb2_grpc as riva_nmt_srv
109
from riva.client import Auth
1110

11+
def streaming_s2s_request_generator(
    audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToSpeechConfig
) -> Generator[riva_nmt.StreamingTranslateSpeechToSpeechRequest, None, None]:
    """Turn audio chunks into a stream of speech-to-speech translation requests.

    The first request yielded carries only the streaming configuration; each
    following request carries one fragment of raw audio.

    Args:
        audio_chunks (:obj:`Iterable[bytes]`): an iterable of raw audio fragments of speech.
        streaming_config (:obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToSpeechConfig`):
            a config for streaming speech-to-speech translation.

    Yields:
        :obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToSpeechRequest`: requests ready to be
        sent on a streaming gRPC call.
    """
    request_cls = riva_nmt.StreamingTranslateSpeechToSpeechRequest
    # Configuration must go out before any audio content.
    yield request_cls(config=streaming_config)
    for audio_chunk in audio_chunks:
        yield request_cls(audio_content=audio_chunk)
17+
18+
def streaming_s2t_request_generator(
    audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToTextConfig
) -> Generator[riva_nmt.StreamingTranslateSpeechToTextRequest, None, None]:
    """Turn audio chunks into a stream of speech-to-text translation requests.

    The first request yielded carries only the streaming configuration; each
    following request carries one fragment of raw audio.

    Args:
        audio_chunks (:obj:`Iterable[bytes]`): an iterable of raw audio fragments of speech.
        streaming_config (:obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToTextConfig`):
            a config for streaming speech-to-text translation.

    Yields:
        :obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToTextRequest`: requests ready to be
        sent on a streaming gRPC call.
    """
    request_cls = riva_nmt.StreamingTranslateSpeechToTextRequest
    # Configuration must go out before any audio content.
    yield request_cls(config=streaming_config)
    for audio_chunk in audio_chunks:
        yield request_cls(audio_content=audio_chunk)
1224

1325
class NeuralMachineTranslationClient:
1426
"""
@@ -25,6 +37,99 @@ def __init__(self, auth: Auth) -> None:
2537
self.auth = auth
2638
self.stub = riva_nmt_srv.RivaTranslationStub(self.auth.channel)
2739

40+
def streaming_s2s_response_generator(
    self, audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToSpeechConfig
) -> Generator[riva_nmt.StreamingTranslateSpeechToSpeechResponse, None, None]:
    """Generates speech-to-speech translation responses for fragments of speech audio
    in :param:`audio_chunks`. This method is meant for "online" use: translation
    begins as soon as small chunks of audio are acquired.

    All available audio chunks will be sent to the server on the first ``next()`` call.

    Args:
        audio_chunks (:obj:`Iterable[bytes]`): an iterable object which contains raw audio fragments
            of speech. For example, such raw audio can be obtained with

            .. code-block:: python

                import wave
                with wave.open(file_name, 'rb') as wav_f:
                    raw_audio = wav_f.readframes(n_frames)

        streaming_config (:obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToSpeechConfig`):
            a config for streaming. You may find a description of the config fields in message
            ``StreamingTranslateSpeechToSpeechConfig`` in `common repo
            <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
            An example of creating a streaming config:

            .. code-block:: python

                from riva.client import RecognitionConfig, StreamingRecognitionConfig, StreamingTranslateSpeechToSpeechConfig, TranslationConfig, SynthesizeSpeechConfig
                config = RecognitionConfig(enable_automatic_punctuation=True)
                asr_config = StreamingRecognitionConfig(config, interim_results=True)
                translation_config = TranslationConfig(source_language_code="es-US", target_language_code="en-US")
                tts_config = SynthesizeSpeechConfig(sample_rate_hz=44100, voice_name="English-US.Female-1")
                streaming_config = StreamingTranslateSpeechToSpeechConfig(asr_config, translation_config, tts_config)

    Yields:
        :obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToSpeechResponse`: responses for audio
        chunks in :param:`audio_chunks`. You may find a description of the response fields in the declaration
        of the ``StreamingTranslateSpeechToSpeechResponse`` message `here
        <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
    """
    requests = streaming_s2s_request_generator(audio_chunks, streaming_config)
    # Forward every response from the bidirectional streaming RPC to the caller.
    responses = self.stub.StreamingTranslateSpeechToSpeech(requests, metadata=self.auth.get_auth_metadata())
    yield from responses
85+
86+
87+
def streaming_s2t_response_generator(
    self, audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToTextConfig
) -> Generator[riva_nmt.StreamingTranslateSpeechToTextResponse, None, None]:
    """Generates speech-to-text translation responses for fragments of speech audio
    in :param:`audio_chunks`. This method is meant for "online" use: translation
    begins as soon as small chunks of audio are acquired.

    All available audio chunks will be sent to the server on the first ``next()`` call.

    Args:
        audio_chunks (:obj:`Iterable[bytes]`): an iterable object which contains raw audio fragments
            of speech. For example, such raw audio can be obtained with

            .. code-block:: python

                import wave
                with wave.open(file_name, 'rb') as wav_f:
                    raw_audio = wav_f.readframes(n_frames)

        streaming_config (:obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToTextConfig`):
            a config for streaming. You may find a description of the config fields in message
            ``StreamingTranslateSpeechToTextConfig`` in `common repo
            <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
            An example of creating a streaming config:

            .. code-block:: python

                from riva.client import RecognitionConfig, StreamingRecognitionConfig, StreamingTranslateSpeechToTextConfig, TranslationConfig
                config = RecognitionConfig(enable_automatic_punctuation=True)
                asr_config = StreamingRecognitionConfig(config, interim_results=True)
                translation_config = TranslationConfig(source_language_code="es-US", target_language_code="en-US")
                streaming_config = StreamingTranslateSpeechToTextConfig(asr_config, translation_config)

    Yields:
        :obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToTextResponse`: responses for audio
        chunks in :param:`audio_chunks`. You may find a description of the response fields in the declaration
        of the ``StreamingTranslateSpeechToTextResponse`` message `here
        <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
    """
    requests = streaming_s2t_request_generator(audio_chunks, streaming_config)
    # Forward every response from the bidirectional streaming RPC to the caller.
    responses = self.stub.StreamingTranslateSpeechToText(requests, metadata=self.auth.get_auth_metadata())
    yield from responses
131+
132+
28133
def translate(
29134
self,
30135
texts: List[str],

0 commit comments

Comments (0)