Skip to content

Commit 0532a1f

Browse files
asr: minor updates to argument names and logging
1 parent 153ebf0 commit 0532a1f

File tree

6 files changed

+74
-67
lines changed

6 files changed

+74
-67
lines changed

riva/client/argparse_utils.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55

66

77
def add_asr_config_argparse_parameters(
8-
parser: argparse.ArgumentParser, max_alternatives: bool = False, profanity_filter: bool = False, word_time_offsets: bool = False
8+
parser: argparse.ArgumentParser,
9+
max_alternatives: bool = False,
10+
profanity_filter: bool = False,
11+
word_time_offsets: bool = False,
912
) -> argparse.ArgumentParser:
1013
if word_time_offsets:
1114
parser.add_argument(
@@ -20,11 +23,11 @@ def add_asr_config_argparse_parameters(
2023
)
2124
if profanity_filter:
2225
parser.add_argument(
23-
"--profanity-filter",
24-
default=False,
25-
action='store_true',
26-
help="Flag that controls the profanity filtering in the generated transcripts",
27-
)
26+
"--profanity-filter",
27+
default=False,
28+
action='store_true',
29+
help="Flag that controls the profanity filtering in the generated transcripts",
30+
)
2831
parser.add_argument(
2932
"--automatic-punctuation",
3033
default=False,
@@ -38,6 +41,7 @@ def add_asr_config_argparse_parameters(
3841
help="If specified, text inverse normalization will be applied",
3942
)
4043
parser.add_argument("--language-code", default="en-US", help="Language code of the model to be used.")
44+
parser.add_argument("--model-name", default="", help="Name of the model to be used.")
4145
parser.add_argument("--boosted-lm-words", action='append', help="Words to boost when decoding.")
4246
parser.add_argument(
4347
"--boosted-lm-score", type=float, default=4.0, help="Value by which to boost words when decoding."

riva/client/asr.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def get_wav_file_parameters(input_file: Union[str, os.PathLike]) -> Dict[str, Un
3030
'duration': nframes / rate,
3131
'nchannels': wf.getnchannels(),
3232
'sampwidth': wf.getsampwidth(),
33-
'data_offset': wf.getfp().size_read + wf.getfp().offset
33+
'data_offset': wf.getfp().size_read + wf.getfp().offset,
3434
}
3535
except:
3636
# Not a WAV file
@@ -46,11 +46,11 @@ class AudioChunkFileIterator:
4646
def __init__(
4747
self,
4848
input_file: Union[str, os.PathLike],
49-
chunk_n_frames: int,
49+
chunk_duration_ms: int,
5050
delay_callback: Optional[Callable[[bytes, float], None]] = None,
5151
) -> None:
5252
self.input_file: Path = Path(input_file).expanduser()
53-
self.chunk_n_frames = chunk_n_frames
53+
self.chunk_duration_ms = chunk_duration_ms
5454
self.delay_callback = delay_callback
5555
self.file_parameters = get_wav_file_parameters(self.input_file)
5656
self.file_object: Optional[typing.BinaryIO] = open(str(self.input_file), 'rb')
@@ -75,16 +75,21 @@ def __iter__(self):
7575

7676
def __next__(self) -> bytes:
7777
if self.file_parameters:
78-
data = self.file_object.read(self.chunk_n_frames * self.file_parameters['sampwidth'] * self.file_parameters['nchannels'])
78+
num_frames = int(self.chunk_duration_ms * self.file_parameters['framerate'] / 1000)
79+
data = self.file_object.read(
80+
num_frames * self.file_parameters['sampwidth'] * self.file_parameters['nchannels']
81+
)
7982
else:
80-
data = self.file_object.read(self.chunk_n_frames)
83+
# Fixed chunk size when file_parameters is not available
84+
data = self.file_object.read(8192)
8185
if not data:
8286
self.close()
8387
raise StopIteration
8488
if self.delay_callback is not None:
8589
offset = self.file_parameters['data_offset'] if self.first_buffer else 0
8690
self.delay_callback(
87-
data[offset:], (len(data) - offset) / self.file_parameters['sampwidth'] / self.file_parameters['framerate']
91+
data[offset:],
92+
(len(data) - offset) / self.file_parameters['sampwidth'] / self.file_parameters['framerate'],
8893
)
8994
self.first_buffer = False
9095
return data
@@ -104,8 +109,7 @@ def add_word_boosting_to_config(
104109

105110

106111
def add_audio_file_specs_to_config(
107-
config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig],
108-
audio_file: Union[str, os.PathLike],
112+
config: Union[rasr.StreamingRecognitionConfig, rasr.RecognitionConfig], audio_file: Union[str, os.PathLike],
109113
) -> None:
110114
inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
111115
wav_parameters = get_wav_file_parameters(audio_file)
@@ -114,10 +118,7 @@ def add_audio_file_specs_to_config(
114118
inner_config.audio_channel_count = wav_parameters['nchannels']
115119

116120

117-
def add_speaker_diarization_to_config(
118-
config: Union[rasr.RecognitionConfig],
119-
diarization_enable: bool,
120-
) -> None:
121+
def add_speaker_diarization_to_config(config: Union[rasr.RecognitionConfig], diarization_enable: bool,) -> None:
121122
inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
122123
if diarization_enable:
123124
diarization_config = rasr.SpeakerDiarizationConfig(enable_speaker_diarization=True)
@@ -129,6 +130,7 @@ def add_speaker_diarization_to_config(
129130

130131
def print_streaming(
131132
responses: Iterable[rasr.StreamingRecognizeResponse],
133+
input_file: str = None,
132134
output_file: Optional[Union[Union[os.PathLike, str, TextIO], List[Union[os.PathLike, str, TextIO]]]] = None,
133135
additional_info: str = 'no',
134136
word_time_offsets: bool = False,
@@ -194,6 +196,10 @@ def print_streaming(
194196
output_file[i] = Path(elem).expanduser().open(file_mode)
195197
start_time = time.time() # used in 'time` additional_info
196198
num_chars_printed = 0 # used in 'no' additional_info
199+
final_transcript = "" # for printing best final transcript
200+
if input_file:
201+
for f in output_file:
202+
f.write(f"File: {input_file}\n")
197203
for response in responses:
198204
if not response.results:
199205
continue
@@ -204,6 +210,7 @@ def print_streaming(
204210
transcript = result.alternatives[0].transcript
205211
if additional_info == 'no':
206212
if result.is_final:
213+
final_transcript += transcript
207214
if show_intermediate:
208215
overwrite_chars = ' ' * (num_chars_printed - len(transcript))
209216
for i, f in enumerate(output_file):
@@ -221,6 +228,7 @@ def print_streaming(
221228
partial_transcript += transcript
222229
elif additional_info == 'time':
223230
if result.is_final:
231+
final_transcript += transcript
224232
for i, alternative in enumerate(result.alternatives):
225233
for f in output_file:
226234
f.write(
@@ -239,6 +247,7 @@ def print_streaming(
239247
partial_transcript += transcript
240248
else: # additional_info == 'confidence'
241249
if result.is_final:
250+
final_transcript += transcript
242251
for f in output_file:
243252
f.write(f'## {transcript}\n')
244253
f.write(f'Confidence: {result.alternatives[0].confidence:9.4f}\n')
@@ -259,6 +268,9 @@ def print_streaming(
259268
else:
260269
for f in output_file:
261270
f.write('----\n')
271+
for f in output_file:
272+
f.write(f"Final transcripts:\n")
273+
f.write(f"0 : {final_transcript}\n")
262274
finally:
263275
for fo, elem in zip(file_opened, output_file):
264276
if fo:
@@ -284,6 +296,7 @@ def streaming_request_generator(
284296

285297
class ASRService:
286298
"""Provides streaming and offline recognition services. Calls gRPC stubs with authentication metadata."""
299+
287300
def __init__(self, auth: Auth) -> None:
288301
"""
289302
Initializes an instance of the class.

riva/client/audio_io.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22
# SPDX-License-Identifier: MIT
33

44
import queue
5-
from typing import Dict, Union, Optional
5+
from typing import Dict, Optional, Union
66

77
import pyaudio
88

99

1010
class MicrophoneStream:
1111
"""Opens a recording stream as responses yielding the audio chunks."""
1212

13-
def __init__(self, rate: int, chunk: int, device: int = None) -> None:
13+
def __init__(self, rate: int, chunk_duration_ms: int, device: int = None) -> None:
1414
self._rate = rate
15-
self._chunk = chunk
15+
self._chunk = int(chunk_duration_ms * rate / 1000)
1616
self._device = device
1717

1818
# Create a thread-safe buffer of audio data
@@ -115,9 +115,7 @@ def list_input_devices() -> None:
115115

116116

117117
class SoundCallBack:
118-
def __init__(
119-
self, output_device_index: Optional[int], sampwidth: int, nchannels: int, framerate: int,
120-
) -> None:
118+
def __init__(self, output_device_index: Optional[int], sampwidth: int, nchannels: int, framerate: int,) -> None:
121119
self.pa = pyaudio.PyAudio()
122120
self.stream = self.pa.open(
123121
output_device_index=output_device_index,

scripts/asr/riva_streaming_asr_client.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from typing import Union
1111

1212
import riva.client
13-
from riva.client.asr import get_wav_file_parameters
1413
from riva.client.argparse_utils import add_asr_config_argparse_parameters, add_connection_argparse_parameters
14+
from riva.client.asr import get_wav_file_parameters
1515

1616

1717
def parse_args() -> argparse.Namespace:
@@ -23,7 +23,7 @@ def parse_args() -> argparse.Namespace:
2323
"which names follow a format `output_<thread_num>.txt`.",
2424
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
2525
)
26-
parser.add_argument("--num-clients", default=1, type=int, help="Number of client threads.")
26+
parser.add_argument("--num-parallel-requests", default=1, type=int, help="Number of client threads.")
2727
parser.add_argument("--num-iterations", default=1, type=int, help="Number of iterations over the file.")
2828
parser.add_argument(
2929
"--input-file", required=True, type=str, help="Name of the WAV file with LINEAR_PCM encoding to transcribe."
@@ -34,11 +34,14 @@ def parse_args() -> argparse.Namespace:
3434
help="Option to simulate realtime transcription. Audio fragments are sent to a server at a pace that mimics "
3535
"normal speech.",
3636
)
37+
parser.add_argument("--chunk-duration-ms", type=int, default=100, help="Chunk duration in milliseconds.")
3738
parser.add_argument(
38-
"--file-streaming-chunk", type=int, default=1600, help="Number of frames in one chunk sent to server."
39+
"--interim-results", default=False, action='store_true', help="Print intermediate transcripts",
3940
)
4041
parser = add_connection_argparse_parameters(parser)
41-
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
42+
parser = add_asr_config_argparse_parameters(
43+
parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True
44+
)
4245
args = parser.parse_args()
4346
if args.max_alternatives < 1:
4447
parser.error("`--max-alternatives` must be greater than or equal to 1")
@@ -60,24 +63,23 @@ def streaming_transcription_worker(
6063
enable_automatic_punctuation=args.automatic_punctuation,
6164
verbatim_transcripts=not args.no_verbatim_transcripts,
6265
enable_word_time_offsets=args.word_time_offsets,
66+
model=args.model_name,
6367
),
64-
interim_results=True,
68+
interim_results=args.interim_results,
6569
)
6670
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
6771
for _ in range(args.num_iterations):
6872
with riva.client.AudioChunkFileIterator(
6973
args.input_file,
70-
args.file_streaming_chunk,
74+
args.chunk_duration_ms,
7175
delay_callback=riva.client.sleep_audio_length if args.simulate_realtime else None,
7276
) as audio_chunk_iterator:
7377
riva.client.print_streaming(
7478
responses=asr_service.streaming_response_generator(
75-
audio_chunks=audio_chunk_iterator,
76-
streaming_config=config,
79+
audio_chunks=audio_chunk_iterator, streaming_config=config,
7780
),
81+
input_file=args.input_file,
7882
output_file=output_file,
79-
additional_info='time',
80-
file_mode='a',
8183
word_time_offsets=args.word_time_offsets,
8284
)
8385
except BaseException as e:
@@ -87,12 +89,12 @@ def streaming_transcription_worker(
8789

8890
def main() -> None:
8991
args = parse_args()
90-
print("Number of clients:", args.num_clients)
92+
print("Number of clients:", args.num_parallel_requests)
9193
print("Number of iteration:", args.num_iterations)
9294
print("Input file:", args.input_file)
9395
threads = []
9496
exception_queue = queue.Queue()
95-
for i in range(args.num_clients):
97+
for i in range(args.num_parallel_requests):
9698
t = Thread(target=streaming_transcription_worker, args=[args, f"output_{i:d}.txt", i, exception_queue])
9799
t.start()
98100
threads.append(t)
@@ -112,7 +114,8 @@ def main() -> None:
112114
if all_dead:
113115
break
114116
time.sleep(0.05)
115-
print(str(args.num_clients), "threads done, output written to output_<thread_id>.txt")
117+
for i in range(args.num_parallel_requests):
118+
print(f"Thread {i} done, output written to output_{i}.txt")
116119

117120

118121
if __name__ == "__main__":

scripts/asr/transcribe_file.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,27 +18,22 @@ def parse_args() -> argparse.Namespace:
1818
parser.add_argument("--input-file", help="A path to a local file to stream.")
1919
parser.add_argument("--list-devices", action="store_true", help="List output devices indices")
2020
parser.add_argument(
21-
"--show-intermediate", action="store_true", help="Show intermediate transcripts as they are available."
21+
"--interim-results", default=False, action='store_true', help="Print intermediate transcripts",
2222
)
2323
parser.add_argument(
2424
"--output-device",
2525
type=int,
2626
default=None,
2727
help="Output audio device to use for playing audio simultaneously with transcribing. If this parameter is "
28-
"provided, then you do not have to `--play-audio` option."
28+
"provided, then you do not have to `--play-audio` option.",
2929
)
3030
parser.add_argument(
3131
"--play-audio",
3232
action="store_true",
3333
help="Whether to play input audio simultaneously with transcribing. If `--output-device` is not provided, "
3434
"then the default output audio device will be used.",
3535
)
36-
parser.add_argument(
37-
"--file-streaming-chunk",
38-
type=int,
39-
default=1600,
40-
help="A maximum number of frames in one chunk sent to server.",
41-
)
36+
parser.add_argument("--chunk-duration-ms", type=int, default=100, help="Chunk duration in milliseconds.")
4237
parser.add_argument(
4338
"--simulate-realtime",
4439
action='store_true',
@@ -49,7 +44,9 @@ def parse_args() -> argparse.Namespace:
4944
"--print-confidence", action="store_true", help="Whether to print stability and confidence of transcript."
5045
)
5146
parser = add_connection_argparse_parameters(parser)
52-
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
47+
parser = add_asr_config_argparse_parameters(
48+
parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True
49+
)
5350
args = parser.parse_args()
5451
if not args.list_devices and args.input_file is None:
5552
parser.error(
@@ -71,12 +68,14 @@ def main() -> None:
7168
config = riva.client.StreamingRecognitionConfig(
7269
config=riva.client.RecognitionConfig(
7370
language_code=args.language_code,
74-
max_alternatives=1,
71+
max_alternatives=args.max_alternatives,
7572
profanity_filter=args.profanity_filter,
7673
enable_automatic_punctuation=args.automatic_punctuation,
7774
verbatim_transcripts=not args.no_verbatim_transcripts,
75+
enable_word_time_offsets=args.word_time_offsets,
76+
model=args.model_name,
7877
),
79-
interim_results=True,
78+
interim_results=args.interim_results,
8079
)
8180
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
8281
sound_callback = None
@@ -90,14 +89,14 @@ def main() -> None:
9089
else:
9190
delay_callback = riva.client.sleep_audio_length if args.simulate_realtime else None
9291
with riva.client.AudioChunkFileIterator(
93-
args.input_file, args.file_streaming_chunk, delay_callback,
92+
args.input_file, args.chunk_duration_ms, delay_callback,
9493
) as audio_chunk_iterator:
9594
riva.client.print_streaming(
9695
responses=asr_service.streaming_response_generator(
97-
audio_chunks=audio_chunk_iterator,
98-
streaming_config=config,
96+
audio_chunks=audio_chunk_iterator, streaming_config=config,
9997
),
100-
show_intermediate=args.show_intermediate,
98+
input_file=args.input_file,
99+
show_intermediate=args.interim_results,
101100
additional_info="confidence" if args.print_confidence else "no",
102101
)
103102
finally:

0 commit comments

Comments
 (0)