77from dataclasses import asdict , dataclass
88from inspect import signature
99from math import ceil
10- from typing import BinaryIO , Iterable , List , Optional , Tuple , Union
10+ from typing import Any , BinaryIO , Iterable , List , Optional , Tuple , Union
1111from warnings import warn
1212
1313import ctranslate2
@@ -82,11 +82,11 @@ class TranscriptionOptions:
8282 compression_ratio_threshold : Optional [float ]
8383 condition_on_previous_text : bool
8484 prompt_reset_on_temperature : float
85- temperatures : List [float ]
85+ temperatures : Union [ List [float ], Tuple [ float , ...] ]
8686 initial_prompt : Optional [Union [str , Iterable [int ]]]
8787 prefix : Optional [str ]
8888 suppress_blank : bool
89- suppress_tokens : Optional [List [int ]]
89+ suppress_tokens : Union [List [int ], Tuple [ int , ... ]]
9090 without_timestamps : bool
9191 max_initial_timestamp : float
9292 word_timestamps : bool
@@ -108,7 +108,7 @@ class TranscriptionInfo:
108108 duration_after_vad : float
109109 all_language_probs : Optional [List [Tuple [str , float ]]]
110110 transcription_options : TranscriptionOptions
111- vad_options : VadOptions
111+ vad_options : Optional [ VadOptions ]
112112
113113
114114class BatchedInferencePipeline :
@@ -123,7 +123,6 @@ def forward(self, features, tokenizer, chunks_metadata, options):
123123 encoder_output , outputs = self .generate_segment_batched (
124124 features , tokenizer , options
125125 )
126-
127126 segmented_outputs = []
128127 segment_sizes = []
129128 for chunk_metadata , output in zip (chunks_metadata , outputs ):
@@ -132,8 +131,8 @@ def forward(self, features, tokenizer, chunks_metadata, options):
132131 segment_sizes .append (segment_size )
133132 (
134133 subsegments ,
135- seek ,
136- single_timestamp_ending ,
134+ _ ,
135+ _ ,
137136 ) = self .model ._split_segments_by_timestamps (
138137 tokenizer = tokenizer ,
139138 tokens = output ["tokens" ],
@@ -288,7 +287,7 @@ def transcribe(
288287 hallucination_silence_threshold : Optional [float ] = None ,
289288 batch_size : int = 8 ,
290289 hotwords : Optional [str ] = None ,
291- language_detection_threshold : Optional [ float ] = 0.5 ,
290+ language_detection_threshold : float = 0.5 ,
292291 language_detection_segments : int = 1 ,
293292 ) -> Tuple [Iterable [Segment ], TranscriptionInfo ]:
294293 """transcribe audio in chunks in batched fashion and return with language info.
@@ -576,7 +575,7 @@ def __init__(
576575 num_workers : int = 1 ,
577576 download_root : Optional [str ] = None ,
578577 local_files_only : bool = False ,
579- files : dict = None ,
578+ files : Optional [ dict ] = None ,
580579 ** model_kwargs ,
581580 ):
582581 """Initializes the Whisper model.
@@ -729,7 +728,7 @@ def transcribe(
729728 clip_timestamps : Union [str , List [float ]] = "0" ,
730729 hallucination_silence_threshold : Optional [float ] = None ,
731730 hotwords : Optional [str ] = None ,
732- language_detection_threshold : Optional [ float ] = 0.5 ,
731+ language_detection_threshold : float = 0.5 ,
733732 language_detection_segments : int = 1 ,
734733 ) -> Tuple [Iterable [Segment ], TranscriptionInfo ]:
735734 """Transcribes an input file.
@@ -833,7 +832,7 @@ def transcribe(
833832 elif isinstance (vad_parameters , dict ):
834833 vad_parameters = VadOptions (** vad_parameters )
835834 speech_chunks = get_speech_timestamps (audio , vad_parameters )
836- audio_chunks , chunks_metadata = collect_chunks (audio , speech_chunks )
835+ audio_chunks , _ = collect_chunks (audio , speech_chunks )
837836 audio = np .concatenate (audio_chunks , axis = 0 )
838837 duration_after_vad = audio .shape [0 ] / sampling_rate
839838
@@ -933,7 +932,7 @@ def transcribe(
933932 condition_on_previous_text = condition_on_previous_text ,
934933 prompt_reset_on_temperature = prompt_reset_on_temperature ,
935934 temperatures = (
936- temperature if isinstance (temperature , (list , tuple )) else [temperature ]
935+ temperature if isinstance (temperature , (List , Tuple )) else [temperature ]
937936 ),
938937 initial_prompt = initial_prompt ,
939938 prefix = prefix ,
@@ -962,7 +961,8 @@ def transcribe(
962961
963962 if speech_chunks :
964963 segments = restore_speech_timestamps (segments , speech_chunks , sampling_rate )
965-
964+ if isinstance (vad_parameters , dict ):
965+ vad_parameters = VadOptions (** vad_parameters )
966966 info = TranscriptionInfo (
967967 language = language ,
968968 language_probability = language_probability ,
@@ -982,7 +982,7 @@ def _split_segments_by_timestamps(
982982 segment_size : int ,
983983 segment_duration : float ,
984984 seek : int ,
985- ) -> List [List [int ] ]:
985+ ) -> Tuple [List [Any ], int , bool ]:
986986 current_segments = []
987987 single_timestamp_ending = (
988988 len (tokens ) >= 2 and tokens [- 2 ] < tokenizer .timestamp_begin <= tokens [- 1 ]
@@ -1550,8 +1550,8 @@ def add_word_timestamps(
15501550 num_frames : int ,
15511551 prepend_punctuations : str ,
15521552 append_punctuations : str ,
1553- last_speech_timestamp : float ,
1554- ) -> float :
1553+ last_speech_timestamp : Union [ float , None ] ,
1554+ ) -> Optional [ float ] :
15551555 if len (segments ) == 0 :
15561556 return
15571557
@@ -1698,9 +1698,11 @@ def find_alignment(
16981698 text_indices = np .array ([pair [0 ] for pair in alignments ])
16991699 time_indices = np .array ([pair [1 ] for pair in alignments ])
17001700
1701- words , word_tokens = tokenizer .split_to_word_tokens (
1702- text_token + [tokenizer .eot ]
1703- )
1701+ if isinstance (text_token , int ):
1702+ tokens = [text_token ] + [tokenizer .eot ]
1703+ else :
1704+ tokens = text_token + [tokenizer .eot ]
1705+ words , word_tokens = tokenizer .split_to_word_tokens (tokens )
17041706 if len (word_tokens ) <= 1 :
17051707 # return on eot only
17061708 # >>> np.pad([], (1, 0))
@@ -1746,7 +1748,7 @@ def detect_language(
17461748 audio : Optional [np .ndarray ] = None ,
17471749 features : Optional [np .ndarray ] = None ,
17481750 vad_filter : bool = False ,
1749- vad_parameters : Union [dict , VadOptions ] = None ,
1751+ vad_parameters : Optional [ Union [dict , VadOptions ] ] = None ,
17501752 language_detection_segments : int = 1 ,
17511753 language_detection_threshold : float = 0.5 ,
17521754 ) -> Tuple [str , float , List [Tuple [str , float ]]]:
@@ -1778,18 +1780,24 @@ def detect_language(
17781780 if audio is not None :
17791781 if vad_filter :
17801782 speech_chunks = get_speech_timestamps (audio , vad_parameters )
1781- audio_chunks , chunks_metadata = collect_chunks (audio , speech_chunks )
1783+ audio_chunks , _ = collect_chunks (audio , speech_chunks )
17821784 audio = np .concatenate (audio_chunks , axis = 0 )
1783-
1785+ assert (
1786+ audio is not None
1787+ ), "Audio have a problem while concatanating the audio_chunks; return None"
17841788 audio = audio [
17851789 : language_detection_segments * self .feature_extractor .n_samples
17861790 ]
17871791 features = self .feature_extractor (audio )
1788-
1792+ assert (
1793+ features is not None
1794+ ), "No features extracted from audio file; return None"
17891795 features = features [
17901796 ..., : language_detection_segments * self .feature_extractor .nb_max_frames
17911797 ]
1792-
1798+ assert (
1799+ features is not None
1800+ ), "No features extracted when detectting language in audio segments; return None"
17931801 detected_language_info = {}
17941802 for i in range (0 , features .shape [- 1 ], self .feature_extractor .nb_max_frames ):
17951803 encoder_output = self .encode (
@@ -1859,13 +1867,13 @@ def get_compression_ratio(text: str) -> float:
18591867
18601868def get_suppressed_tokens (
18611869 tokenizer : Tokenizer ,
1862- suppress_tokens : Tuple [int ],
1863- ) -> Optional [List [int ]]:
1864- if - 1 in suppress_tokens :
1870+ suppress_tokens : Optional [List [int ]],
1871+ ) -> Tuple [int , ...]:
1872+ if suppress_tokens is None or len (suppress_tokens ) == 0 :
1873+ suppress_tokens = [] # interpret empty string as an empty list
1874+ elif - 1 in suppress_tokens :
18651875 suppress_tokens = [t for t in suppress_tokens if t >= 0 ]
18661876 suppress_tokens .extend (tokenizer .non_speech_tokens )
1867- elif suppress_tokens is None or len (suppress_tokens ) == 0 :
1868- suppress_tokens = [] # interpret empty string as an empty list
18691877 else :
18701878 assert isinstance (suppress_tokens , list ), "suppress_tokens must be a list"
18711879
0 commit comments