Skip to content

Commit 1024749

Browse files
committed
utilise the Silero VAD model on PyTorch Hub
1 parent e147990 commit 1024749

File tree

4 files changed

+21
-14
lines changed

4 files changed

+21
-14
lines changed

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ dependencies = [
8484
"tf-keras~=2.19.0; python_version >= '3.12'",
8585
"tensorflow>=1.15.5,<2.16.0; python_version < '3.12'",
8686
"tensorflow~=2.19.0; python_version >= '3.12'",
87+
"tensorflow-metal~=1.2.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
8788
"termcolor==1.1.0",
8889
"toml==0.10.0",
8990
"toolz==0.9.0",
@@ -106,7 +107,6 @@ harmony = [
106107
"accelerate~=1.12.0",
107108
"soxr==1.0.0",
108109
"webrtcvad~=2.0.10",
109-
"silero-vad~=6.2.0",
110110
]
111111
dev = [
112112
"aeneas~=1.7.3.0; python_version < '3.12'",
@@ -120,7 +120,6 @@ dev = [
120120
"soxr==1.0.0",
121121
"accelerate~=1.12.0",
122122
"webrtcvad~=2.0.10",
123-
"silero-vad~=6.2.0",
124123
"mock==4.0.3",
125124
"coverage==5.5",
126125
"tox~=3.23.0",
@@ -159,7 +158,6 @@ llm = [
159158
"accelerate~=1.12.0",
160159
"soxr==1.0.0",
161160
"webrtcvad~=2.0.10",
162-
"silero-vad~=6.2.0",
163161
]
164162

165163
[project.scripts]

subaligner/transcriber.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(self, recipe: str = TranscriptionRecipe.WHISPER.value, flavour: str
7373
self.__flavour = flavour
7474
self.__media_helper = MediaHelper()
7575
self.__lock = Lock()
76+
self.vad_model: Optional[Any] = None
7677

7778
if recipe == TranscriptionRecipe.WHISPER.value:
7879
if flavour not in [f.value for f in WhisperFlavour]:
@@ -143,7 +144,9 @@ def transcribe(self,
143144
self.__LOGGER.debug("Prompting with: '%s'" % initial_prompt)
144145

145146
audio, sr = self.__load_audio(audio_file_path, target_sample_rate=sample_rate)
146-
segments = Utils.vad_segment(audio, sample_rate=sr, recipe="silero")
147+
segments, self.vad_model = Utils.vad_segment(
148+
audio, sample_rate=sr, recipe="silero", model_local=self.vad_model
149+
)
147150
self.__LOGGER.info("Segments detected with voice activities")
148151

149152
final_segments = []

subaligner/utils.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -778,7 +778,8 @@ def vad_segment(audio: np.ndarray,
778778
frame_ms: int = 30,
779779
aggressiveness: int = 3,
780780
min_speech_ms: int = 200,
781-
recipe: str = "webrtcvad") -> List[Tuple[int, int]]:
781+
recipe: str = "webrtcvad",
782+
model_local: Optional[Any] = None) -> Tuple[List[Tuple[int, int]], Any]:
782783
"""Segment audio into speech and non-speech segments using the selected VAD recipe (WebRTC VAD or Silero VAD).
783784
784785
Arguments:
@@ -788,9 +789,9 @@ def vad_segment(audio: np.ndarray,
788789
aggressiveness {int} -- The aggressiveness of the VAD (0-3).
789790
min_speech_ms {int} -- The minimum duration of a speech segment in milliseconds.
790791
recipe {str} -- The VAD recipe to use ("webrtcvad" or "silero").
791-
792+
model_local {Optional[Any]} -- A previously loaded VAD model to reuse; if None, the model is loaded on first use.
792793
Returns:
793-
List[Tuple[int, int]]: A list of tuples representing the start and end samples of speech segments.
794+
Tuple[List[Tuple[int, int]], Any]: A list of tuples representing the start and end samples of speech segments, and the loaded VAD model.
794795
795796
Raises:
796797
ValueError: If an unsupported VAD recipe is provided.
@@ -831,20 +832,25 @@ def vad_segment(audio: np.ndarray,
831832
if cur_start is not None and cur_end is not None:
832833
if (cur_end - cur_start) >= int(min_speech_ms * sample_rate / 1000):
833834
segments.append((cur_start, cur_end))
834-
return segments
835+
return segments, model_local
835836
elif recipe == "silero":
836-
from silero_vad import load_silero_vad, get_speech_timestamps
837-
model = load_silero_vad()
837+
if model_local is None:
838+
model_local, utils = torch.hub.load(
839+
repo_or_dir="snakers4/silero-vad:be95df9152c0d7618fa1edfeb296fc3dae32376f", # v6.2
840+
model="silero_vad",
841+
force_reload=False,
842+
)
843+
(get_speech_timestamps, _, read_audio, *_) = utils
838844
speech_timestamps = get_speech_timestamps(
839845
torch.tensor(audio, dtype=torch.float32),
840-
model,
846+
model_local,
841847
sampling_rate=sample_rate,
842848
return_seconds=True,
843849
)
844850
segments = []
845851
for ts in speech_timestamps:
846852
segments.append((int(ts['start'] * sample_rate), int(ts['end'] * sample_rate)))
847-
return segments
853+
return segments, model_local
848854
else:
849855
raise ValueError("Unsupported VAD recipe: {}".format(recipe))
850856

tests/subaligner/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ def test_vad_segment_webrtcvad(self):
427427
audio = audio.astype("float32") / maxv
428428
audio = audio.astype("float32")
429429

430-
segments = Undertest.vad_segment(
430+
segments, _ = Undertest.vad_segment(
431431
audio, sample_rate=sr, frame_ms=30, aggressiveness=2, min_speech_ms=300, recipe="webrtcvad"
432432
)
433433

@@ -449,7 +449,7 @@ def test_vad_segment_silero(self):
449449
audio = audio.astype("float32") / maxv
450450
audio = audio.astype("float32")
451451

452-
segments = Undertest.vad_segment(audio, sample_rate=sr, recipe="silero")
452+
segments, _ = Undertest.vad_segment(audio, sample_rate=sr, recipe="silero")
453453

454454
self.assertGreater(len(segments), 0)
455455
for start, end in segments:

0 commit comments

Comments
 (0)