# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
- from typing import Generator, Optional, Union, List
-
+ from typing import Callable, Dict, Generator, Iterable, List, Optional, TextIO, Union
from grpc._channel import _MultiThreadedRendezvous

import riva.client.proto.riva_nmt_pb2 as riva_nmt
import riva.client.proto.riva_nmt_pb2_grpc as riva_nmt_srv
from riva.client import Auth

+ def streaming_s2s_request_generator(
+     audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToSpeechConfig
+ ) -> Generator[riva_nmt.StreamingTranslateSpeechToSpeechRequest, None, None]:
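+     """Yields a request carrying ``streaming_config``, then one request per chunk in ``audio_chunks``."""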
+     yield riva_nmt.StreamingTranslateSpeechToSpeechRequest(config=streaming_config)
+     for chunk in audio_chunks:
+         yield riva_nmt.StreamingTranslateSpeechToSpeechRequest(audio_content=chunk)
+
+ def streaming_s2t_request_generator(
+     audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToTextConfig
+ ) -> Generator[riva_nmt.StreamingTranslateSpeechToTextRequest, None, None]:
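+     """Yields a request carrying ``streaming_config``, then one request per chunk in ``audio_chunks``."""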
+     yield riva_nmt.StreamingTranslateSpeechToTextRequest(config=streaming_config)
+     for chunk in audio_chunks:
+         yield riva_nmt.StreamingTranslateSpeechToTextRequest(audio_content=chunk)

class NeuralMachineTranslationClient:
    """
@@ -25,6 +37,99 @@ def __init__(self, auth: Auth) -> None:
        self.auth = auth
        self.stub = riva_nmt_srv.RivaTranslationStub(self.auth.channel)

+     def streaming_s2s_response_generator(
+         self, audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToSpeechConfig
+     ) -> Generator[riva_nmt.StreamingTranslateSpeechToSpeechResponse, None, None]:
+         """
+         Generates speech-to-speech translation responses for fragments of speech audio in :param:`audio_chunks`.
+         The purpose of the method is to perform speech-to-speech translation "online", i.e. to translate
+         small chunks of audio as soon as they are acquired.
+
+         All available audio chunks are sent to the server on the first ``next()`` call.
+
+         Args:
+             audio_chunks (:obj:`Iterable[bytes]`): an iterable object which contains raw audio fragments
+                 of speech. For example, such raw audio can be obtained with
+
+                 .. code-block:: python
+
+                     import wave
+                     with wave.open(file_name, 'rb') as wav_f:
+                         raw_audio = wav_f.readframes(n_frames)
+
+             streaming_config (:obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToSpeechConfig`): a config for streaming.
+                 A description of config fields is given for the ``StreamingTranslateSpeechToSpeechConfig`` message in the
+                 `common repo
+                 <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
+                 An example of creating a streaming config:
+
+                 .. code-block:: python
+
+                     from riva.client import RecognitionConfig, StreamingRecognitionConfig, StreamingTranslateSpeechToSpeechConfig, TranslationConfig, SynthesizeSpeechConfig
+                     config = RecognitionConfig(enable_automatic_punctuation=True)
+                     asr_config = StreamingRecognitionConfig(config, interim_results=True)
+                     translation_config = TranslationConfig(source_language_code="es-US", target_language_code="en-US")
+                     tts_config = SynthesizeSpeechConfig(sample_rate_hz=44100, voice_name="English-US.Female-1")
+                     streaming_config = StreamingTranslateSpeechToSpeechConfig(asr_config, translation_config, tts_config)
+
+         Yields:
+             :obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToSpeechResponse`: responses for audio chunks in
+                 :param:`audio_chunks`. A description of response fields is given for the
+                 ``StreamingTranslateSpeechToSpeechResponse`` message `here
+                 <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
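+
+         Example of usage (a minimal sketch; it assumes ``auth``, ``audio_chunks``, and ``streaming_config``
+         were created as described above, and that responses follow the layout in the proto reference linked above):
+
+         .. code-block:: python
+
+             nmt_client = NeuralMachineTranslationClient(auth)
+             responses = nmt_client.streaming_s2s_response_generator(audio_chunks, streaming_config)
+             translated_audio = b""
+             for response in responses:
+                 # each response carries a fragment of synthesized speech in the target language
+                 translated_audio += response.speech.audio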
+         """
+         generator = streaming_s2s_request_generator(audio_chunks, streaming_config)
+         for response in self.stub.StreamingTranslateSpeechToSpeech(generator, metadata=self.auth.get_auth_metadata()):
+             yield response
+
+
+     def streaming_s2t_response_generator(
+         self, audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToTextConfig
+     ) -> Generator[riva_nmt.StreamingTranslateSpeechToTextResponse, None, None]:
+         """
+         Generates speech-to-text translation responses for fragments of speech audio in :param:`audio_chunks`.
+         The purpose of the method is to perform speech-to-text translation "online", i.e. to translate
+         small chunks of audio as soon as they are acquired.
+
+         All available audio chunks are sent to the server on the first ``next()`` call.
+
+         Args:
+             audio_chunks (:obj:`Iterable[bytes]`): an iterable object which contains raw audio fragments
+                 of speech. For example, such raw audio can be obtained with
+
+                 .. code-block:: python
+
+                     import wave
+                     with wave.open(file_name, 'rb') as wav_f:
+                         raw_audio = wav_f.readframes(n_frames)
+
+             streaming_config (:obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToTextConfig`): a config for streaming.
+                 A description of config fields is given for the ``StreamingTranslateSpeechToTextConfig`` message in the
+                 `common repo
+                 <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
+                 An example of creating a streaming config:
+
+                 .. code-block:: python
+
+                     from riva.client import RecognitionConfig, StreamingRecognitionConfig, StreamingTranslateSpeechToTextConfig, TranslationConfig
+                     config = RecognitionConfig(enable_automatic_punctuation=True)
+                     asr_config = StreamingRecognitionConfig(config, interim_results=True)
+                     translation_config = TranslationConfig(source_language_code="es-US", target_language_code="en-US")
+                     streaming_config = StreamingTranslateSpeechToTextConfig(asr_config, translation_config)
+
+         Yields:
+             :obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToTextResponse`: responses for audio chunks in
+                 :param:`audio_chunks`. A description of response fields is given for the
+                 ``StreamingTranslateSpeechToTextResponse`` message `here
+                 <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
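+
+         Example of usage (a minimal sketch; it assumes ``auth``, ``audio_chunks``, and ``streaming_config``
+         were created as described above, and that responses mirror streaming ASR results as documented in the
+         proto reference linked above):
+
+         .. code-block:: python
+
+             nmt_client = NeuralMachineTranslationClient(auth)
+             responses = nmt_client.streaming_s2t_response_generator(audio_chunks, streaming_config)
+             for response in responses:
+                 for result in response.results:
+                     # best current hypothesis of the translated text
+                     print(result.alternatives[0].transcript)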
+         """
+         generator = streaming_s2t_request_generator(audio_chunks, streaming_config)
+         for response in self.stub.StreamingTranslateSpeechToText(generator, metadata=self.auth.get_auth_metadata()):
+             yield response
+
+
    def translate(
        self,
        texts: List[str],