Skip to content

Commit 232c737

Browse files
Release 2.19.0 (#122)
* Add speaker diarization to streaming ASR clients (#116) * enable speaker diarization for streaming_asr_client * add: speaker diarization to transcribe_file * fix: print confidence when word offsets are disabled * Fix output format of transcribe_file.py client * updates: add options to ASR and TTS clients (#118) - Add/update list model option to ASR clients - Add encoding option to TTS client --------- Co-authored-by: Viraj Karandikar <[email protected]>
1 parent 943aa96 commit 232c737

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

riva/client/asr.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def print_streaming(
183183
word_time_offsets: bool = False,
184184
show_intermediate: bool = False,
185185
file_mode: str = 'w',
186+
speaker_diarization: bool = False,
186187
) -> None:
187188
"""
188189
Prints streaming speech recognition results to provided files or streams.
@@ -284,12 +285,21 @@ def print_streaming(
284285
if word_time_offsets:
285286
for f in output_file:
286287
f.write("Timestamps:\n")
287-
f.write('{: <40s}{: <16s}{: <16s}\n'.format('Word', 'Start (ms)', 'End (ms)'))
288+
temp = '{: <40s}{: <16s}{: <16s}'
289+
value = ['Word', 'Start (ms)', 'End (ms)']
290+
if speaker_diarization:
291+
temp += '{: <16s}'
292+
value.append('Speaker')
293+
temp += '\n'
294+
f.write(temp.format(*value))
288295
for word_info in result.alternatives[0].words:
289296
f.write(
290297
f'{word_info.word: <40s}{word_info.start_time: <16.0f}'
291-
f'{word_info.end_time: <16.0f}\n'
298+
f'{word_info.end_time: <16.0f}'
292299
)
300+
if speaker_diarization:
301+
f.write(f'{word_info.speaker_tag: <16d}')
302+
f.write('\n')
293303
else:
294304
partial_transcript += transcript
295305
else: # additional_info == 'confidence'

scripts/asr/riva_streaming_asr_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def streaming_transcription_worker(
6060
profanity_filter=args.profanity_filter,
6161
enable_automatic_punctuation=args.automatic_punctuation,
6262
verbatim_transcripts=not args.no_verbatim_transcripts,
63-
enable_word_time_offsets=args.word_time_offsets,
63+
enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization,
6464
),
6565
interim_results=True,
6666
)
@@ -78,6 +78,7 @@ def streaming_transcription_worker(
7878
args.custom_configuration
7979
)
8080
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
81+
riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization, args.diarization_max_speakers)
8182
for _ in range(args.num_iterations):
8283
with riva.client.AudioChunkFileIterator(
8384
args.input_file,
@@ -92,7 +93,8 @@ def streaming_transcription_worker(
9293
output_file=output_file,
9394
additional_info='time',
9495
file_mode='a',
95-
word_time_offsets=args.word_time_offsets,
96+
word_time_offsets=args.word_time_offsets or args.speaker_diarization,
97+
speaker_diarization=args.speaker_diarization,
9698
)
9799
except BaseException as e:
98100
exception_queue.put((e, thread_i))

scripts/asr/transcribe_file.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def parse_args() -> argparse.Namespace:
5050
"normal speech.",
5151
)
5252
parser.add_argument(
53-
"--print-confidence", action="store_true", help="Whether to print stability and confidence of transcript."
53+
"--print-confidence", action="store_true", help="Whether to print stability and confidence of transcript. If `--word-time-offsets` or `--speaker-diarization` is set, then confidence is not printed."
5454
)
5555
parser = add_connection_argparse_parameters(parser)
5656
parser = add_asr_config_argparse_parameters(parser, max_alternatives=True, profanity_filter=True, word_time_offsets=True)
@@ -97,10 +97,12 @@ def main() -> None:
9797
profanity_filter=args.profanity_filter,
9898
enable_automatic_punctuation=args.automatic_punctuation,
9999
verbatim_transcripts=not args.no_verbatim_transcripts,
100+
enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization,
100101
),
101102
interim_results=True,
102103
)
103104
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
105+
riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization, args.diarization_max_speakers)
104106
riva.client.add_endpoint_parameters_to_config(
105107
config,
106108
args.start_history,
@@ -133,7 +135,9 @@ def main() -> None:
133135
streaming_config=config,
134136
),
135137
show_intermediate=args.show_intermediate,
136-
additional_info="confidence" if args.print_confidence else "no",
138+
additional_info="time" if (args.word_time_offsets or args.speaker_diarization) else ("confidence" if args.print_confidence else "no"),
139+
word_time_offsets=args.word_time_offsets or args.speaker_diarization,
140+
speaker_diarization=args.speaker_diarization,
137141
)
138142
finally:
139143
if sound_callback is not None and sound_callback.opened:

0 commit comments

Comments
 (0)