Commit 4470cc1

feat: Add Voice Activity Detection and Speaker Diarization support
- Introduced VAD functionality to filter silent audio regions, improving transcription efficiency.
- Added speaker diarization capabilities using pyannote.audio, allowing identification of speakers in multi-speaker audio.
- Updated CLI and README to reflect new features and usage examples.
- Enhanced transcribe function to support VAD and diarization options.
- Implemented RTTM format output for diarization results.

Signed-off-by: sealad886 <[email protected]>
1 parent 7ddca42 commit 4470cc1

8 files changed (+959 −17 lines)

whisper/README.md

Lines changed: 70 additions & 0 deletions
````diff
@@ -82,6 +82,76 @@ To see more transcription options use:
 
 >>> help(mlx_whisper.transcribe)
 ```
 
+### Voice Activity Detection (VAD)
+
+Enable Silero VAD to filter silent audio regions before transcription. This can
+significantly speed up transcription for audio with long silent periods:
+
+```bash
+# Enable VAD
+mlx_whisper audio.mp3 --vad-filter
+
+# Customize VAD settings
+mlx_whisper audio.mp3 --vad-filter --vad-threshold 0.6 --vad-min-silence-ms 1000
+```
+
+In Python:
+
+```python
+from mlx_whisper import transcribe
+from mlx_whisper.vad import VadOptions
+
+result = transcribe("audio.mp3", vad_filter=True)
+
+# With custom options
+vad_opts = VadOptions(threshold=0.6, min_silence_duration_ms=1000)
+result = transcribe("audio.mp3", vad_filter=True, vad_options=vad_opts)
+```
+
+**Requirements**: `pip install torch`
+
+### Speaker Diarization
+
+Identify who is speaking when with pyannote.audio. Diarization adds speaker
+labels to transcription segments:
+
+```bash
+# Enable diarization (requires HuggingFace token)
+export HF_TOKEN=your_token
+mlx_whisper audio.mp3 --diarize --word-timestamps
+
+# Specify speaker count
+mlx_whisper audio.mp3 --diarize --min-speakers 2 --max-speakers 4
+
+# Output diarization in RTTM format
+mlx_whisper audio.mp3 --diarize -f rttm
+```
+
+In Python:
+
+```python
+from mlx_whisper import transcribe_with_diarization
+
+result = transcribe_with_diarization(
+    "audio.mp3",
+    hf_token="your_token",
+    word_timestamps=True
+)
+
+# Access speaker info
+for segment in result["segments"]:
+    speaker = segment.get("speaker", "Unknown")
+    print(f"{speaker}: {segment['text']}")
+
+# List of speakers
+print(result["speakers"])  # ['SPEAKER_00', 'SPEAKER_01', ...]
+```
+
+**Requirements**:
+- `pip install pyannote.audio pandas`
+- Accept model terms at https://huggingface.co/pyannote/speaker-diarization-3.1
+- Set `HF_TOKEN` environment variable or pass `--hf-token`
+
 ### Converting models
 
 > [!TIP]
````
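The diff does not show `mlx_whisper/vad.py` itself, but the three CLI knobs map onto the upstream Silero VAD parameters of the same names. A standalone sketch using the upstream `silero-vad` model via `torch.hub` (illustrative only, not this repo's internal API; the file name is a placeholder):

```python
import torch

# Load the upstream Silero VAD model and its helper utilities.
model, utils = torch.hub.load("snakers4/silero-vad", "silero_vad")
get_speech_timestamps, _, read_audio, _, _ = utils

wav = read_audio("audio.wav", sampling_rate=16000)
speech = get_speech_timestamps(
    wav,
    model,
    threshold=0.5,                 # cf. --vad-threshold
    min_silence_duration_ms=2000,  # cf. --vad-min-silence-ms
    speech_pad_ms=400,             # cf. --vad-speech-pad-ms
)
print(speech)  # [{'start': ..., 'end': ...}, ...] sample offsets of speech regions
```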

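For the new `-f rttm` output, RTTM is the NIST-standard diarization format: one `SPEAKER` record per speaker turn carrying the file ID, channel, onset, and duration in seconds, plus the speaker label. Illustrative lines (values invented):

```
SPEAKER audio 1 0.000 4.250 <NA> <NA> SPEAKER_00 <NA> <NA>
SPEAKER audio 1 4.250 2.100 <NA> <NA> SPEAKER_01 <NA> <NA>
```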
whisper/mlx_whisper/__init__.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -2,4 +2,15 @@
 
 from . import audio, decoding, load_models
 from ._version import __version__
-from .transcribe import transcribe
+from .transcribe import transcribe, transcribe_with_diarization
+
+# Optional modules (may not be available if dependencies are missing or incompatible)
+try:
+    from . import vad
+except (ImportError, AttributeError):
+    vad = None
+
+try:
+    from . import diarize
+except (ImportError, AttributeError):
+    diarize = None
```
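Since `vad` and `diarize` degrade to `None` instead of raising at import time, callers can feature-detect them before use. A minimal sketch built only on names visible in this commit:

```python
import mlx_whisper

# mlx_whisper.vad is None when torch is missing or incompatible;
# mlx_whisper.diarize is None when pyannote.audio is unavailable.
if mlx_whisper.vad is not None:
    opts = mlx_whisper.vad.VadOptions(threshold=0.6)
    result = mlx_whisper.transcribe("audio.mp3", vad_filter=True, vad_options=opts)
else:
    # Fall back to plain transcription without VAD.
    result = mlx_whisper.transcribe("audio.mp3")
```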

whisper/mlx_whisper/cli.py

Lines changed: 127 additions & 6 deletions
```diff
@@ -59,7 +59,7 @@ def str2bool(string):
         "-f",
         type=str,
         default="txt",
-        choices=["txt", "vtt", "srt", "tsv", "json", "all"],
+        choices=["txt", "vtt", "srt", "tsv", "json", "rttm", "all"],
         help="Format of the output file",
     )
     parser.add_argument(
@@ -92,6 +92,12 @@ def str2bool(string):
         default=5,
         help="Number of candidates when sampling with non-zero temperature",
     )
+    parser.add_argument(
+        "--beam-size",
+        type=optional_int,
+        default=None,
+        help="Beam size for beam search (currently not implemented; option will be ignored)",
+    )
     parser.add_argument(
         "--patience",
         type=float,
@@ -199,6 +205,69 @@ def str2bool(string):
         default="0",
         help="Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file",
     )
+    # VAD arguments
+    parser.add_argument(
+        "--vad-filter",
+        type=str2bool,
+        default=False,
+        help="Enable Silero VAD to filter silent audio before transcription",
+    )
+    parser.add_argument(
+        "--vad-threshold",
+        type=float,
+        default=0.5,
+        help="VAD speech detection threshold (0.0-1.0)",
+    )
+    parser.add_argument(
+        "--vad-min-silence-ms",
+        type=int,
+        default=2000,
+        help="Minimum silence duration to split speech segments (ms)",
+    )
+    parser.add_argument(
+        "--vad-speech-pad-ms",
+        type=int,
+        default=400,
+        help="Padding added around speech segments (ms)",
+    )
+    # Diarization arguments
+    parser.add_argument(
+        "--diarize",
+        type=str2bool,
+        default=False,
+        help="Enable speaker diarization (requires pyannote.audio)",
+    )
+    parser.add_argument(
+        "--hf-token",
+        type=str,
+        default=None,
+        help="HuggingFace token for pyannote models (or set HF_TOKEN env var)",
+    )
+    parser.add_argument(
+        "--diarize-model",
+        type=str,
+        default="pyannote/speaker-diarization-3.1",
+        help="Diarization model to use",
+    )
+    parser.add_argument(
+        "--min-speakers",
+        type=optional_int,
+        default=None,
+        help="Minimum number of speakers for diarization",
+    )
+    parser.add_argument(
+        "--max-speakers",
+        type=optional_int,
+        default=None,
+        help="Maximum number of speakers for diarization",
+    )
+    parser.add_argument(
+        "--diarize-device",
+        type=str,
+        default="cpu",
+        choices=["cpu", "cuda", "mps"],
+        help="Device for diarization model",
+    )
     return parser
 
 
@@ -232,6 +301,40 @@ def main():
     if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
         warnings.warn("--max-words-per-line has no effect with --max-line-width")
 
+    # Extract VAD options
+    vad_filter = args.pop("vad_filter")
+    vad_threshold = args.pop("vad_threshold")
+    vad_min_silence_ms = args.pop("vad_min_silence_ms")
+    vad_speech_pad_ms = args.pop("vad_speech_pad_ms")
+
+    vad_options = None
+    if vad_filter:
+        from .vad import VadOptions
+
+        vad_options = VadOptions(
+            threshold=vad_threshold,
+            min_silence_duration_ms=vad_min_silence_ms,
+            speech_pad_ms=vad_speech_pad_ms,
+        )
+    elif any(
+        [vad_threshold != 0.5, vad_min_silence_ms != 2000, vad_speech_pad_ms != 400]
+    ):
+        warnings.warn("VAD options have no effect without --vad-filter")
+
+    # Extract diarization options
+    diarize = args.pop("diarize")
+    hf_token = args.pop("hf_token") or os.environ.get("HF_TOKEN")
+    diarize_model = args.pop("diarize_model")
+    min_speakers = args.pop("min_speakers")
+    max_speakers = args.pop("max_speakers")
+    diarize_device = args.pop("diarize_device")
+
+    if diarize and not hf_token:
+        warnings.warn(
+            "Diarization requires a HuggingFace token. "
+            "Set --hf-token or HF_TOKEN environment variable."
+        )
+
     for audio_obj in args.pop("audio"):
         if audio_obj == "-":
             # receive the contents from stdin rather than read a file
@@ -241,11 +344,29 @@ def main():
         else:
             output_name = output_name or pathlib.Path(audio_obj).stem
         try:
-            result = transcribe(
-                audio_obj,
-                path_or_hf_repo=path_or_hf_repo,
-                **args,
-            )
+            if diarize:
+                from .transcribe import transcribe_with_diarization
+
+                result = transcribe_with_diarization(
+                    audio_obj,
+                    path_or_hf_repo=path_or_hf_repo,
+                    hf_token=hf_token,
+                    diarize_model=diarize_model,
+                    min_speakers=min_speakers,
+                    max_speakers=max_speakers,
+                    device=diarize_device,
+                    vad_filter=vad_filter,
+                    vad_options=vad_options,
+                    **args,
+                )
+            else:
+                result = transcribe(
+                    audio_obj,
+                    path_or_hf_repo=path_or_hf_repo,
+                    vad_filter=vad_filter,
+                    vad_options=vad_options,
+                    **args,
+                )
             writer(result, output_name, **writer_args)
         except Exception as e:
             traceback.print_exc()
```
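Putting the two code paths together, the diarization branch above amounts to the direct call below; a sketch using only parameters visible in this diff, with the CLI defaults spelled out:

```python
import os

from mlx_whisper.transcribe import transcribe_with_diarization
from mlx_whisper.vad import VadOptions

# Roughly: mlx_whisper audio.mp3 --diarize --vad-filter --word-timestamps
result = transcribe_with_diarization(
    "audio.mp3",
    hf_token=os.environ.get("HF_TOKEN"),
    diarize_model="pyannote/speaker-diarization-3.1",
    min_speakers=None,  # None lets the model infer the speaker count
    max_speakers=None,
    device="cpu",
    vad_filter=True,
    vad_options=VadOptions(
        threshold=0.5, min_silence_duration_ms=2000, speech_pad_ms=400
    ),
    word_timestamps=True,
)
```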
