Skip to content

Commit 9a2cd82

Browse files
authored
Support setting max speakers for offline diarization (#97)
* fix: accept input for max_speaker_count in asr/transcribe_file_offline
* fix: rename input field to diarization_max_speakers
* remove: redundant default value for max_speakers
1 parent b94b3a9 commit 9a2cd82

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

riva/client/argparse_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ def add_asr_config_argparse_parameters(
         action='store_true',
         help="Flag that controls if speaker diarization should be performed",
     )
+    parser.add_argument(
+        "--diarization-max-speakers",
+        default=3,
+        type=int,
+        help="Max number of speakers to detect when performing speaker diarization",
+    )
     parser.add_argument(
         "--start-history",
         default=-1,

riva/client/asr.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,14 @@ def add_audio_file_specs_to_config(
 def add_speaker_diarization_to_config(
     config: Union[rasr.RecognitionConfig],
     diarization_enable: bool,
+    diarization_max_speakers: int,
 ) -> None:
     inner_config: rasr.RecognitionConfig = config if isinstance(config, rasr.RecognitionConfig) else config.config
     if diarization_enable:
-        diarization_config = rasr.SpeakerDiarizationConfig(enable_speaker_diarization=True)
+        diarization_config = rasr.SpeakerDiarizationConfig(
+            enable_speaker_diarization=True,
+            max_speaker_count=diarization_max_speakers,
+        )
         inner_config.diarization_config.CopyFrom(diarization_config)



scripts/asr/transcribe_file_offline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def main() -> None:
         enable_word_time_offsets=args.word_time_offsets or args.speaker_diarization,
     )
     riva.client.add_word_boosting_to_config(config, args.boosted_lm_words, args.boosted_lm_score)
-    riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization)
+    riva.client.add_speaker_diarization_to_config(config, args.speaker_diarization, args.diarization_max_speakers)
     riva.client.add_endpoint_parameters_to_config(
         config,
         args.start_history,

0 commit comments

Comments
 (0)