Skip to content

Commit 2086d96

Browse files
asr: minor updates to argument names and logging
1 parent 153ebf0 commit 2086d96

File tree

4 files changed

+34
-21
lines changed

4 files changed

+34
-21
lines changed

riva/client/argparse_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def add_asr_config_argparse_parameters(
3838
help="If specified, text inverse normalization will be applied",
3939
)
4040
parser.add_argument("--language-code", default="en-US", help="Language code of the model to be used.")
41+
parser.add_argument("--model-name", default=None, help="Name of the model to be used to be used.")
4142
parser.add_argument("--boosted-lm-words", action='append', help="Words to boost when decoding.")
4243
parser.add_argument(
4344
"--boosted-lm-score", type=float, default=4.0, help="Value by which to boost words when decoding."

riva/client/asr.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ class AudioChunkFileIterator:
4646
def __init__(
4747
self,
4848
input_file: Union[str, os.PathLike],
49-
chunk_n_frames: int,
49+
chunk_duration_ms: int,
5050
delay_callback: Optional[Callable[[bytes, float], None]] = None,
5151
) -> None:
5252
self.input_file: Path = Path(input_file).expanduser()
53-
self.chunk_n_frames = chunk_n_frames
53+
self.chunk_duration_ms = chunk_duration_ms
5454
self.delay_callback = delay_callback
5555
self.file_parameters = get_wav_file_parameters(self.input_file)
5656
self.file_object: Optional[typing.BinaryIO] = open(str(self.input_file), 'rb')
@@ -75,9 +75,11 @@ def __iter__(self):
7575

7676
def __next__(self) -> bytes:
7777
if self.file_parameters:
78-
data = self.file_object.read(self.chunk_n_frames * self.file_parameters['sampwidth'] * self.file_parameters['nchannels'])
78+
num_frames = int(self.chunk_duration_ms * self.file_parameters['framerate'] / 1000)
79+
data = self.file_object.read(num_frames * self.file_parameters['sampwidth'] * self.file_parameters['nchannels'])
7980
else:
80-
data = self.file_object.read(self.chunk_n_frames)
81+
# Fixed chunk size when file_parameters is not available
82+
data = self.file_object.read(8192)
8183
if not data:
8284
self.close()
8385
raise StopIteration
@@ -129,6 +131,7 @@ def add_speaker_diarization_to_config(
129131

130132
def print_streaming(
131133
responses: Iterable[rasr.StreamingRecognizeResponse],
134+
input_file: str,
132135
output_file: Optional[Union[Union[os.PathLike, str, TextIO], List[Union[os.PathLike, str, TextIO]]]] = None,
133136
additional_info: str = 'no',
134137
word_time_offsets: bool = False,
@@ -194,6 +197,9 @@ def print_streaming(
194197
output_file[i] = Path(elem).expanduser().open(file_mode)
195198
start_time = time.time() # used in 'time` additional_info
196199
num_chars_printed = 0 # used in 'no' additional_info
200+
final_transcript = "" # for printing best final transcript
201+
for f in output_file:
202+
f.write(f"File: {input_file}\n")
197203
for response in responses:
198204
if not response.results:
199205
continue
@@ -204,6 +210,7 @@ def print_streaming(
204210
transcript = result.alternatives[0].transcript
205211
if additional_info == 'no':
206212
if result.is_final:
213+
final_transcript += transcript
207214
if show_intermediate:
208215
overwrite_chars = ' ' * (num_chars_printed - len(transcript))
209216
for i, f in enumerate(output_file):
@@ -221,6 +228,7 @@ def print_streaming(
221228
partial_transcript += transcript
222229
elif additional_info == 'time':
223230
if result.is_final:
231+
final_transcript += transcript
224232
for i, alternative in enumerate(result.alternatives):
225233
for f in output_file:
226234
f.write(
@@ -239,6 +247,7 @@ def print_streaming(
239247
partial_transcript += transcript
240248
else: # additional_info == 'confidence'
241249
if result.is_final:
250+
final_transcript += transcript
242251
for f in output_file:
243252
f.write(f'## {transcript}\n')
244253
f.write(f'Confidence: {result.alternatives[0].confidence:9.4f}\n')
@@ -259,6 +268,9 @@ def print_streaming(
259268
else:
260269
for f in output_file:
261270
f.write('----\n')
271+
for f in output_file:
272+
f.write(f"Final transcripts:\n")
273+
f.write(f"0 : {final_transcript}\n")
262274
finally:
263275
for fo, elem in zip(file_opened, output_file):
264276
if fo:

scripts/asr/riva_streaming_asr_client.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def parse_args() -> argparse.Namespace:
2323
"which names follow a format `output_<thread_num>.txt`.",
2424
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
2525
)
26-
parser.add_argument("--num-clients", default=1, type=int, help="Number of client threads.")
26+
parser.add_argument("--num-parallel-requests", default=1, type=int, help="Number of client threads.")
2727
parser.add_argument("--num-iterations", default=1, type=int, help="Number of iterations over the file.")
2828
parser.add_argument(
2929
"--input-file", required=True, type=str, help="Name of the WAV file with LINEAR_PCM encoding to transcribe."
@@ -35,7 +35,10 @@ def parse_args() -> argparse.Namespace:
3535
"normal speech.",
3636
)
3737
parser.add_argument(
38-
"--file-streaming-chunk", type=int, default=1600, help="Number of frames in one chunk sent to server."
38+
"--chunk-duration-ms", type=int, default=100, help="Chunk duration in milliseconds."
39+
)
40+
parser.add_argument(
41+
"--interim-results", default=False, action='store_true', help="Print intermediate transcripts",
3942
)
4043
parser = add_connection_argparse_parameters(parser)
4144
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
@@ -61,23 +64,22 @@ def streaming_transcription_worker(
6164
verbatim_transcripts=not args.no_verbatim_transcripts,
6265
enable_word_time_offsets=args.word_time_offsets,
6366
),
64-
interim_results=True,
67+
interim_results=args.interim_results,
6568
)
6669
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
6770
for _ in range(args.num_iterations):
6871
with riva.client.AudioChunkFileIterator(
6972
args.input_file,
70-
args.file_streaming_chunk,
73+
args.chunk_duration_ms,
7174
delay_callback=riva.client.sleep_audio_length if args.simulate_realtime else None,
7275
) as audio_chunk_iterator:
7376
riva.client.print_streaming(
7477
responses=asr_service.streaming_response_generator(
7578
audio_chunks=audio_chunk_iterator,
7679
streaming_config=config,
7780
),
81+
input_file=args.input_file,
7882
output_file=output_file,
79-
additional_info='time',
80-
file_mode='a',
8183
word_time_offsets=args.word_time_offsets,
8284
)
8385
except BaseException as e:
@@ -87,12 +89,12 @@ def streaming_transcription_worker(
8789

8890
def main() -> None:
8991
args = parse_args()
90-
print("Number of clients:", args.num_clients)
92+
print("Number of clients:", args.num_parallel_requests)
9193
print("Number of iteration:", args.num_iterations)
9294
print("Input file:", args.input_file)
9395
threads = []
9496
exception_queue = queue.Queue()
95-
for i in range(args.num_clients):
97+
for i in range(args.num_parallel_requests):
9698
t = Thread(target=streaming_transcription_worker, args=[args, f"output_{i:d}.txt", i, exception_queue])
9799
t.start()
98100
threads.append(t)
@@ -112,7 +114,7 @@ def main() -> None:
112114
if all_dead:
113115
break
114116
time.sleep(0.05)
115-
print(str(args.num_clients), "threads done, output written to output_<thread_id>.txt")
117+
print(str(args.num_parallel_requests), "threads done, output written to output_<thread_id>.txt")
116118

117119

118120
if __name__ == "__main__":

scripts/asr/transcribe_file.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def parse_args() -> argparse.Namespace:
1818
parser.add_argument("--input-file", help="A path to a local file to stream.")
1919
parser.add_argument("--list-devices", action="store_true", help="List output devices indices")
2020
parser.add_argument(
21-
"--show-intermediate", action="store_true", help="Show intermediate transcripts as they are available."
21+
"--interim-results", default=False, action='store_true', help="Print intermediate transcripts",
2222
)
2323
parser.add_argument(
2424
"--output-device",
@@ -34,10 +34,7 @@ def parse_args() -> argparse.Namespace:
3434
"then the default output audio device will be used.",
3535
)
3636
parser.add_argument(
37-
"--file-streaming-chunk",
38-
type=int,
39-
default=1600,
40-
help="A maximum number of frames in one chunk sent to server.",
37+
"--chunk-duration-ms", type=int, default=100, help="Chunk duration in milliseconds."
4138
)
4239
parser.add_argument(
4340
"--simulate-realtime",
@@ -76,7 +73,7 @@ def main() -> None:
7673
enable_automatic_punctuation=args.automatic_punctuation,
7774
verbatim_transcripts=not args.no_verbatim_transcripts,
7875
),
79-
interim_results=True,
76+
interim_results=args.interim_results,
8077
)
8178
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
8279
sound_callback = None
@@ -90,14 +87,15 @@ def main() -> None:
9087
else:
9188
delay_callback = riva.client.sleep_audio_length if args.simulate_realtime else None
9289
with riva.client.AudioChunkFileIterator(
93-
args.input_file, args.file_streaming_chunk, delay_callback,
90+
args.input_file, args.chunk_duration_ms, delay_callback,
9491
) as audio_chunk_iterator:
9592
riva.client.print_streaming(
9693
responses=asr_service.streaming_response_generator(
9794
audio_chunks=audio_chunk_iterator,
9895
streaming_config=config,
9996
),
100-
show_intermediate=args.show_intermediate,
97+
input_file=args.input_file,
98+
show_intermediate=args.interim_results,
10199
additional_info="confidence" if args.print_confidence else "no",
102100
)
103101
finally:

0 commit comments

Comments
 (0)