
Commit d039127 (parent: be9fb36)

feat: solve conflicts

File tree: 2 files changed (+41, -32 lines)


faster_whisper/transcribe.py

Lines changed: 37 additions & 29 deletions
@@ -7,7 +7,7 @@
 from dataclasses import asdict, dataclass
 from inspect import signature
 from math import ceil
-from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
+from typing import Any, BinaryIO, Iterable, List, Optional, Tuple, Union
 from warnings import warn

 import ctranslate2
@@ -82,11 +82,11 @@ class TranscriptionOptions:
     compression_ratio_threshold: Optional[float]
     condition_on_previous_text: bool
     prompt_reset_on_temperature: float
-    temperatures: List[float]
+    temperatures: Union[List[float], Tuple[float, ...]]
     initial_prompt: Optional[Union[str, Iterable[int]]]
     prefix: Optional[str]
     suppress_blank: bool
-    suppress_tokens: Optional[List[int]]
+    suppress_tokens: Union[List[int], Tuple[int, ...]]
     without_timestamps: bool
     max_initial_timestamp: float
     word_timestamps: bool
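
Widening temperatures and suppress_tokens to also accept tuples matters because callers can pass e.g. temperature=(0.0, 0.2, 0.4) straight through to the options object. A minimal sketch of the effect (field names from the diff above, everything else omitted):

    from dataclasses import dataclass
    from typing import List, Tuple, Union

    @dataclass
    class _OptionsSketch:
        # mirrors the two widened fields of TranscriptionOptions
        temperatures: Union[List[float], Tuple[float, ...]]
        suppress_tokens: Union[List[int], Tuple[int, ...]]

    # lists and tuples are both consistent with the annotations now
    _OptionsSketch(temperatures=[0.0, 0.2], suppress_tokens=[-1])
    _OptionsSketch(temperatures=(0.0, 0.2, 0.4), suppress_tokens=(1, 2))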
@@ -108,7 +108,7 @@ class TranscriptionInfo:
     duration_after_vad: float
     all_language_probs: Optional[List[Tuple[str, float]]]
     transcription_options: TranscriptionOptions
-    vad_options: VadOptions
+    vad_options: Optional[VadOptions]


 class BatchedInferencePipeline:
@@ -123,7 +123,6 @@ def forward(self, features, tokenizer, chunks_metadata, options):
         encoder_output, outputs = self.generate_segment_batched(
             features, tokenizer, options
         )
-
         segmented_outputs = []
         segment_sizes = []
         for chunk_metadata, output in zip(chunks_metadata, outputs):
@@ -132,8 +131,8 @@ def forward(self, features, tokenizer, chunks_metadata, options):
             segment_sizes.append(segment_size)
             (
                 subsegments,
-                seek,
-                single_timestamp_ending,
+                _,
+                _,
             ) = self.model._split_segments_by_timestamps(
                 tokenizer=tokenizer,
                 tokens=output["tokens"],
@@ -288,7 +287,7 @@ def transcribe(
         hallucination_silence_threshold: Optional[float] = None,
         batch_size: int = 8,
         hotwords: Optional[str] = None,
-        language_detection_threshold: Optional[float] = 0.5,
+        language_detection_threshold: float = 0.5,
         language_detection_segments: int = 1,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """transcribe audio in chunks in batched fashion and return with language info.
@@ -576,7 +575,7 @@ def __init__(
         num_workers: int = 1,
         download_root: Optional[str] = None,
         local_files_only: bool = False,
-        files: dict = None,
+        files: Optional[dict] = None,
         **model_kwargs,
     ):
         """Initializes the Whisper model.
@@ -729,7 +728,7 @@ def transcribe(
         clip_timestamps: Union[str, List[float]] = "0",
         hallucination_silence_threshold: Optional[float] = None,
         hotwords: Optional[str] = None,
-        language_detection_threshold: Optional[float] = 0.5,
+        language_detection_threshold: float = 0.5,
         language_detection_segments: int = 1,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """Transcribes an input file.
@@ -833,7 +832,7 @@ def transcribe(
             elif isinstance(vad_parameters, dict):
                 vad_parameters = VadOptions(**vad_parameters)
             speech_chunks = get_speech_timestamps(audio, vad_parameters)
-            audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
+            audio_chunks, _ = collect_chunks(audio, speech_chunks)
             audio = np.concatenate(audio_chunks, axis=0)
             duration_after_vad = audio.shape[0] / sampling_rate

@@ -933,7 +932,7 @@ def transcribe(
             condition_on_previous_text=condition_on_previous_text,
             prompt_reset_on_temperature=prompt_reset_on_temperature,
             temperatures=(
-                temperature if isinstance(temperature, (list, tuple)) else [temperature]
+                temperature if isinstance(temperature, (List, Tuple)) else [temperature]
             ),
             initial_prompt=initial_prompt,
             prefix=prefix,
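
Swapping the runtime check from (list, tuple) to the typing aliases (List, Tuple) is behavior-preserving on current CPython, where the bare aliases delegate isinstance() to the builtin types (only subscripted forms like List[float] are rejected). A quick check:

    from typing import List, Tuple

    assert isinstance([0.0, 0.2], List)       # delegates to list
    assert isinstance((0.0,), Tuple)          # delegates to tuple
    assert not isinstance(0.5, (List, Tuple))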
@@ -962,7 +961,8 @@ def transcribe(

         if speech_chunks:
             segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
-
+        if isinstance(vad_parameters, dict):
+            vad_parameters = VadOptions(**vad_parameters)
         info = TranscriptionInfo(
             language=language,
             language_probability=language_probability,
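
Normalizing a dict-valued vad_parameters into VadOptions here keeps TranscriptionInfo.vad_options consistent with its new Optional[VadOptions] annotation, even on code paths where VAD filtering was skipped. A sketch of the coercion, with a simplified stand-in for VadOptions:

    from dataclasses import dataclass
    from typing import Optional, Union

    @dataclass
    class VadOptions:  # stand-in; the real class lives in faster_whisper.vad
        min_speech_duration_ms: int = 250

    def coerce(vad_parameters: Optional[Union[dict, VadOptions]]) -> Optional[VadOptions]:
        # the same dict -> dataclass normalization the diff inserts before TranscriptionInfo
        if isinstance(vad_parameters, dict):
            return VadOptions(**vad_parameters)
        return vad_parameters

    assert coerce({"min_speech_duration_ms": 100}) == VadOptions(100)
    assert coerce(None) is None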
@@ -982,7 +982,7 @@ def _split_segments_by_timestamps(
         segment_size: int,
         segment_duration: float,
         seek: int,
-    ) -> List[List[int]]:
+    ) -> Tuple[List[Any], int, bool]:
         current_segments = []
         single_timestamp_ending = (
             len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
@@ -1550,8 +1550,8 @@ def add_word_timestamps(
     num_frames: int,
     prepend_punctuations: str,
     append_punctuations: str,
-    last_speech_timestamp: float,
-) -> float:
+    last_speech_timestamp: Union[float, None],
+) -> Optional[float]:
     if len(segments) == 0:
         return

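The widened signature matches the function body: the early bare return on empty segments yields None, so a plain float return type was inaccurate. A hypothetical reduction of the same shape:

    from typing import List, Optional

    def last_end(segments: List[dict]) -> Optional[float]:
        # an early bare return means callers can receive None
        if len(segments) == 0:
            return
        return float(segments[-1]["end"])

    assert last_end([]) is None
    assert last_end([{"end": 3.2}]) == 3.2
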
@@ -1698,9 +1698,11 @@ def find_alignment(
         text_indices = np.array([pair[0] for pair in alignments])
         time_indices = np.array([pair[1] for pair in alignments])

-        words, word_tokens = tokenizer.split_to_word_tokens(
-            text_token + [tokenizer.eot]
-        )
+        if isinstance(text_token, int):
+            tokens = [text_token] + [tokenizer.eot]
+        else:
+            tokens = text_token + [tokenizer.eot]
+        words, word_tokens = tokenizer.split_to_word_tokens(tokens)
         if len(word_tokens) <= 1:
             # return on eot only
             # >>> np.pad([], (1, 0))
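
The new branch guards the case where text_token arrives as a single token id instead of a list, which would otherwise fail on list concatenation. A self-contained sketch of the normalization (token id value hypothetical):

    from typing import List, Union

    EOT = 50257  # hypothetical end-of-transcript token id

    def with_eot(text_token: Union[int, List[int]]) -> List[int]:
        # wrap a bare int before appending the end-of-transcript marker
        if isinstance(text_token, int):
            return [text_token] + [EOT]
        return text_token + [EOT]

    assert with_eot(42) == [42, EOT]
    assert with_eot([1, 2]) == [1, 2, EOT]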
@@ -1746,7 +1748,7 @@ def detect_language(
         audio: Optional[np.ndarray] = None,
         features: Optional[np.ndarray] = None,
         vad_filter: bool = False,
-        vad_parameters: Union[dict, VadOptions] = None,
+        vad_parameters: Optional[Union[dict, VadOptions]] = None,
         language_detection_segments: int = 1,
         language_detection_threshold: float = 0.5,
     ) -> Tuple[str, float, List[Tuple[str, float]]]:
@@ -1778,18 +1780,24 @@ def detect_language(
         if audio is not None:
             if vad_filter:
                 speech_chunks = get_speech_timestamps(audio, vad_parameters)
-                audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
+                audio_chunks, _ = collect_chunks(audio, speech_chunks)
                 audio = np.concatenate(audio_chunks, axis=0)
-
+            assert (
+                audio is not None
+            ), "Audio is None after concatenating the audio_chunks"
             audio = audio[
                 : language_detection_segments * self.feature_extractor.n_samples
             ]
             features = self.feature_extractor(audio)
-
+            assert (
+                features is not None
+            ), "No features extracted from the audio file"
         features = features[
             ..., : language_detection_segments * self.feature_extractor.nb_max_frames
         ]
-
+        assert (
+            features is not None
+        ), "No features extracted when detecting language in audio segments"
         detected_language_info = {}
         for i in range(0, features.shape[-1], self.feature_extractor.nb_max_frames):
             encoder_output = self.encode(
@@ -1859,13 +1867,13 @@ def get_compression_ratio(text: str) -> float:

 def get_suppressed_tokens(
     tokenizer: Tokenizer,
-    suppress_tokens: Tuple[int],
-) -> Optional[List[int]]:
-    if -1 in suppress_tokens:
+    suppress_tokens: Optional[List[int]],
+) -> Tuple[int, ...]:
+    if suppress_tokens is None or len(suppress_tokens) == 0:
+        suppress_tokens = []  # interpret empty string as an empty list
+    elif -1 in suppress_tokens:
         suppress_tokens = [t for t in suppress_tokens if t >= 0]
         suppress_tokens.extend(tokenizer.non_speech_tokens)
-    elif suppress_tokens is None or len(suppress_tokens) == 0:
-        suppress_tokens = []  # interpret empty string as an empty list
     else:
         assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"

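Putting the None/empty check first means -1 in suppress_tokens is never evaluated on None, which would raise TypeError. A runnable sketch of the resulting behavior with a stubbed tokenizer, assuming the function ends by returning tuple(sorted(set(...))) as the new Tuple[int, ...] annotation suggests:

    from typing import List, Optional, Tuple

    class _StubTokenizer:
        # hypothetical stand-in; the real Tokenizer exposes non_speech_tokens
        non_speech_tokens = [220, 1350]

    def get_suppressed_tokens(
        tokenizer: _StubTokenizer,
        suppress_tokens: Optional[List[int]],
    ) -> Tuple[int, ...]:
        if suppress_tokens is None or len(suppress_tokens) == 0:
            suppress_tokens = []  # interpret empty input as an empty list
        elif -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(tokenizer.non_speech_tokens)
        return tuple(sorted(set(suppress_tokens)))

    assert get_suppressed_tokens(_StubTokenizer(), None) == ()
    assert get_suppressed_tokens(_StubTokenizer(), [-1, 7]) == (7, 220, 1350)
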
faster_whisper/vad.py

Lines changed: 4 additions & 3 deletions
@@ -3,7 +3,7 @@
 import os

 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np

@@ -44,7 +44,7 @@ class VadOptions:

 def get_speech_timestamps(
     audio: np.ndarray,
-    vad_options: Optional[VadOptions] = None,
+    vad_options: Optional[Union[dict, VadOptions]] = None,
     sampling_rate: int = 16000,
     **kwargs,
 ) -> List[dict]:
@@ -61,7 +61,8 @@ def get_speech_timestamps(
     """
     if vad_options is None:
         vad_options = VadOptions(**kwargs)
-
+    if isinstance(vad_options, dict):
+        vad_options = VadOptions(**vad_options)
     onset = vad_options.onset
     min_speech_duration_ms = vad_options.min_speech_duration_ms
     max_speech_duration_s = vad_options.max_speech_duration_s
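
With this branch get_speech_timestamps accepts a plain dict as well as a VadOptions instance (or None plus keyword overrides). Usage, with a silent dummy buffer for illustration:

    import numpy as np

    from faster_whisper.vad import VadOptions, get_speech_timestamps

    audio = np.zeros(16000, dtype=np.float32)  # one second of silence

    get_speech_timestamps(audio)                                   # defaults
    get_speech_timestamps(audio, VadOptions(min_speech_duration_ms=100))
    get_speech_timestamps(audio, {"min_speech_duration_ms": 100})  # dict, coerced by the new branch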
