Skip to content

Commit 7b74ab2

Browse files
feat(asr): add speaker diarization support in offline client (#25)
1 parent f5dce10 commit 7b74ab2

File tree

5 files changed

+22
-2
lines changed

5 files changed

+22
-2
lines changed

riva/client/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
ASRService,
77
add_audio_file_specs_to_config,
88
add_word_boosting_to_config,
9+
add_speaker_diarization_to_config,
910
get_wav_file_parameters,
1011
print_offline,
1112
print_streaming,

riva/client/argparse_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ def add_asr_config_argparse_parameters(
4242
parser.add_argument(
4343
"--boosted-lm-score", type=float, default=4.0, help="Value by which to boost words when decoding."
4444
)
45+
parser.add_argument(
46+
"--speaker-diarization",
47+
default=False,
48+
action='store_true',
49+
help="Flag that controls if speaker diarization should be performed",
50+
)
4551
return parser
4652

4753

riva/client/asr.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,16 @@ def add_audio_file_specs_to_config(
9999
inner_config.audio_channel_count = wav_parameters['nchannels']
100100

101101

102+
def add_speaker_diarization_to_config(
103+
config: Union[rasr.RecognitionConfig],
104+
diarization_enable: bool,
105+
) -> None:
106+
inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
107+
if diarization_enable:
108+
diarization_config = rasr.SpeakerDiarizationConfig(enable_speaker_diarization=True)
109+
inner_config.diarization_config.CopyFrom(diarization_config)
110+
111+
102112
PRINT_STREAMING_ADDITIONAL_INFO_MODES = ['no', 'time', 'confidence']
103113

104114

scripts/asr/transcribe_file_offline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def parse_args() -> argparse.Namespace:
1818
)
1919
parser.add_argument("--input-file", required=True, type=Path, help="A path to a local file to transcribe.")
2020
parser = add_connection_argparse_parameters(parser)
21-
parser = add_asr_config_argparse_parameters(parser, profanity_filter=True)
21+
parser = add_asr_config_argparse_parameters(parser, profanity_filter=True, word_time_offsets=True)
2222
args = parser.parse_args()
2323
args.input_file = args.input_file.expanduser()
2424
return args
@@ -35,9 +35,12 @@ def main() -> None:
3535
profanity_filter=args.profanity_filter,
3636
enable_automatic_punctuation=args.automatic_punctuation,
3737
verbatim_transcripts=not args.no_verbatim_transcripts,
38+
enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization,
3839
)
3940
riva.client.add_audio_file_specs_to_config(config, args.input_file)
4041
riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
42+
riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization)
43+
4144
with args.input_file.open('rb') as fh:
4245
data = fh.read()
4346
try:

0 commit comments

Comments
 (0)