diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index 51eb1c50..b051800c 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -253,7 +253,7 @@ def generate_segment_batched( def transcribe( self, - audio: Union[str, BinaryIO, np.ndarray], + audio: Union[str, BinaryIO, np.ndarray, List[Union[str, BinaryIO, np.ndarray]]], language: Optional[str] = None, task: str = "transcribe", log_progress: bool = False, @@ -297,10 +297,14 @@ def transcribe( language_detection_threshold: Optional[float] = 0.5, language_detection_segments: int = 1, ) -> Tuple[Iterable[Segment], TranscriptionInfo]: - """transcribe audio in chunks in batched fashion and return with language info. + """Transcribe audio in chunks in batched fashion and return with language info. + + Supports both single audio and batch processing of multiple audios. + When a list of audios is provided, returns a list of results (one per audio). Arguments: audio: Path to the input file (or a file-like object), or the audio waveform. + Can also be a list of audio inputs for batch processing. language: The language spoken in the audio. It should be a language code such as "en" or "fr". If not set, the language will be detected in the first 30 seconds of audio. @@ -369,12 +373,15 @@ def transcribe( (in seconds) when a possible hallucination is detected. set as None. Returns: A tuple with: - - a generator over transcribed segments - an instance of TranscriptionInfo """ + is_batch = isinstance(audio, list) + audios = audio if is_batch else [audio] + sampling_rate = self.model.feature_extractor.sampling_rate + _chunk_length = chunk_length or self.model.feature_extractor.chunk_length if multilingual and not self.model.model.is_multilingual: self.model.logger.warning( @@ -383,91 +390,310 @@ def transcribe( ) multilingual = False - if not isinstance(audio, np.ndarray): - audio = decode_audio(audio, sampling_rate=sampling_rate) - duration = audio.shape[0] / sampling_rate + _vad_parameters = None + if vad_filter: + if vad_parameters is None: + _vad_parameters = VadOptions( + max_speech_duration_s=_chunk_length, + min_silence_duration_ms=160, + ) + elif isinstance(vad_parameters, dict): + vad_params_copy = vad_parameters.copy() + if "max_speech_duration_s" in vad_params_copy.keys(): + vad_params_copy.pop("max_speech_duration_s") + _vad_parameters = VadOptions( + **vad_params_copy, max_speech_duration_s=_chunk_length + ) + else: + _vad_parameters = vad_parameters - self.model.logger.info( - "Processing audio with duration %s", format_timestamp(duration) - ) + all_features = [] + all_chunks_metadata = [] + audio_infos = [] + audio_boundaries = [0] - chunk_length = chunk_length or self.model.feature_extractor.chunk_length - # if no segment split is provided, use vad_model and generate segments - if not clip_timestamps: - if vad_filter: - if vad_parameters is None: - vad_parameters = VadOptions( - max_speech_duration_s=chunk_length, - min_silence_duration_ms=160, - ) - elif isinstance(vad_parameters, dict): - if "max_speech_duration_s" in vad_parameters.keys(): - vad_parameters.pop("max_speech_duration_s") + for audio_item in audios: + if not isinstance(audio_item, np.ndarray): + audio_item = decode_audio(audio_item, sampling_rate=sampling_rate) - vad_parameters = VadOptions( - **vad_parameters, max_speech_duration_s=chunk_length - ) + duration = audio_item.shape[0] / sampling_rate + + self.model.logger.info( + "Processing audio with duration %s", format_timestamp(duration) + ) + + audio_clip_timestamps = None 
+ clip_timestamps_provided = clip_timestamps is not None - clip_timestamps = get_speech_timestamps(audio, vad_parameters) - # run the audio if it is less than 30 sec even without clip_timestamps - elif duration < chunk_length: - clip_timestamps = [{"start": 0, "end": audio.shape[0]}] + if clip_timestamps_provided: + audio_clip_timestamps = [ + {k: int(v * sampling_rate) for k, v in segment.items()} + for segment in clip_timestamps + ] + audio_chunks, chunks_meta = [], [] + for i, clip in enumerate(audio_clip_timestamps): + audio_chunks.append(audio_item[clip["start"] : clip["end"]]) + clip_duration = (clip["end"] - clip["start"]) / sampling_rate + if clip_duration > 30: + self.model.logger.warning( + "Segment %d is longer than 30 seconds, " + "only the first 30 seconds will be transcribed", + i, + ) + chunks_meta.append( + { + "offset": clip["start"] / sampling_rate, + "duration": clip_duration, + "segments": [clip], + } + ) + elif vad_filter: + audio_clip_timestamps = get_speech_timestamps( + audio_item, _vad_parameters + ) + audio_chunks, chunks_meta = collect_chunks( + audio_item, audio_clip_timestamps, max_duration=_chunk_length + ) + elif duration < _chunk_length: + audio_clip_timestamps = [{"start": 0, "end": audio_item.shape[0]}] + audio_chunks, chunks_meta = collect_chunks( + audio_item, audio_clip_timestamps, max_duration=_chunk_length + ) else: raise RuntimeError( "No clip timestamps found. " "Set 'vad_filter' to True or provide 'clip_timestamps'." ) - clip_timestamps_provided = False - audio_chunks, chunks_metadata = collect_chunks( - audio, clip_timestamps, max_duration=chunk_length + duration_after_vad = ( + sum( + (segment["end"] - segment["start"]) + for segment in audio_clip_timestamps + ) + / sampling_rate ) - else: - clip_timestamps_provided = True - clip_timestamps = [ - {k: int(v * sampling_rate) for k, v in segment.items()} - for segment in clip_timestamps - ] + self.model.logger.info( + "VAD filter removed %s of audio", + format_timestamp(duration - duration_after_vad), + ) + features = ( + [ + self.model.feature_extractor(chunk)[..., :-1] + for chunk in audio_chunks + ] + if duration_after_vad + else [] + ) - audio_chunks, chunks_metadata = [], [] - for i, clip in enumerate(clip_timestamps): - audio_chunks.append(audio[clip["start"] : clip["end"]]) + audio_infos.append( + { + "duration": duration, + "duration_after_vad": duration_after_vad, + "clip_timestamps": audio_clip_timestamps, + } + ) - clip_duration = (clip["end"] - clip["start"]) / sampling_rate - if clip_duration > 30: - self.model.logger.warning( - "Segment %d is longer than 30 seconds, " - "only the first 30 seconds will be transcribed", - i, - ) + all_features.extend(features) + all_chunks_metadata.extend(chunks_meta) + audio_boundaries.append(len(all_features)) - chunks_metadata.append( - { - "offset": clip["start"] / sampling_rate, - "duration": clip_duration, - "segments": [clip], - } + all_language_probs = None + if language is None: + if not self.model.model.is_multilingual: + language = "en" + language_probability = 1 + else: + ( + language, + language_probability, + all_language_probs, + ) = self.model.detect_language( + features=np.concatenate( + all_features + + [ + np.full((self.model.model.n_mels, 1), -1.5, dtype="float32") + ], + axis=1, + ), + language_detection_segments=language_detection_segments, + language_detection_threshold=language_detection_threshold, + ) + + self.model.logger.info( + "Detected language '%s' with probability %.2f", + language, + language_probability, + ) + else: + if 
not self.model.model.is_multilingual and language != "en": + self.model.logger.warning( + "The current model is English-only but the language parameter is set to '%s'; " + "using 'en' instead." % language ) + language = "en" + + language_probability = 1 + + tokenizer = Tokenizer( + self.model.hf_tokenizer, + self.model.model.is_multilingual, + task=task, + language=language, + ) + + all_features = ( + np.stack([pad_or_trim(feature) for feature in all_features]) + if all_features + else [] + ) + + options = TranscriptionOptions( + beam_size=beam_size, + best_of=best_of, + patience=patience, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + log_prob_threshold=log_prob_threshold, + no_speech_threshold=no_speech_threshold, + compression_ratio_threshold=compression_ratio_threshold, + temperatures=( + temperature[:1] + if isinstance(temperature, (list, tuple)) + else [temperature] + ), + initial_prompt=initial_prompt, + prefix=prefix, + suppress_blank=suppress_blank, + suppress_tokens=( + get_suppressed_tokens(tokenizer, suppress_tokens) + if suppress_tokens + else suppress_tokens + ), + prepend_punctuations=prepend_punctuations, + append_punctuations=append_punctuations, + max_new_tokens=max_new_tokens, + hotwords=hotwords, + word_timestamps=word_timestamps, + hallucination_silence_threshold=None, + condition_on_previous_text=False, + clip_timestamps=clip_timestamps, + prompt_reset_on_temperature=0.5, + multilingual=multilingual, + without_timestamps=without_timestamps, + max_initial_timestamp=0.0, + ) + + clip_timestamps_provided = clip_timestamps is not None + + info = TranscriptionInfo( + language=language, + language_probability=language_probability, + duration=audio_infos[0]["duration"], + duration_after_vad=audio_infos[0]["duration_after_vad"], + transcription_options=options, + vad_options=_vad_parameters, + all_language_probs=all_language_probs, + ) - duration_after_vad = ( - sum((segment["end"] - segment["start"]) for segment in clip_timestamps) - / sampling_rate + segments = self._batched_segments_generator( + all_features, + tokenizer, + all_chunks_metadata, + batch_size, + options, + log_progress, ) - self.model.logger.info( - "VAD filter removed %s of audio", - format_timestamp(duration - duration_after_vad), + if not is_batch and not clip_timestamps_provided: + segments = restore_speech_timestamps( + segments, audio_infos[0]["clip_timestamps"], sampling_rate + ) + + return segments, info + + def transcribe_batch_multiple_audios( + self, + audios: List[Union[str, BinaryIO, np.ndarray]], + language: Optional[str] = None, + task: str = "transcribe", + log_progress: bool = False, + beam_size: int = 5, + best_of: int = 5, + patience: float = 1, + length_penalty: float = 1, + repetition_penalty: float = 1, + no_repeat_ngram_size: int = 0, + temperature: Union[float, List[float], Tuple[float, ...]] = [ + 0.0, + 0.2, + 0.4, + 0.6, + 0.8, + 1.0, + ], + compression_ratio_threshold: Optional[float] = 2.4, + log_prob_threshold: Optional[float] = -1.0, + no_speech_threshold: Optional[float] = 0.6, + condition_on_previous_text: bool = True, + prompt_reset_on_temperature: float = 0.5, + initial_prompt: Optional[Union[str, Iterable[int]]] = None, + prefix: Optional[str] = None, + suppress_blank: bool = True, + suppress_tokens: Optional[List[int]] = [-1], + without_timestamps: bool = True, + max_initial_timestamp: float = 1.0, + word_timestamps: bool = False, + prepend_punctuations: str = "\"'" "¿([{-", + append_punctuations: 
str = "\"'.。,,!!??::”)]}、", + multilingual: bool = False, + vad_filter: bool = False, + vad_parameters: Optional[Union[dict, VadOptions]] = None, + max_new_tokens: Optional[int] = None, + chunk_length: Optional[int] = None, + clip_timestamps: Optional[List[dict]] = None, + hallucination_silence_threshold: Optional[float] = None, + batch_size: int = 8, + hotwords: Optional[str] = None, + language_detection_threshold: Optional[float] = 0.5, + language_detection_segments: int = 1, + text_only: bool = False, + ) -> Tuple[Iterable[Segment], TranscriptionInfo]: + """Deprecated: Use transcribe() with a list of audios instead.""" + warn( + "transcribe_batch_multiple_audios() is deprecated, " + "use transcribe() with a list of audios instead", + DeprecationWarning, + 2, ) + sampling_rate = self.model.feature_extractor.sampling_rate + + if multilingual and not self.model.model.is_multilingual: + self.model.logger.warning( + "The current model is English-only but the multilingual parameter is set to" + "True; setting to False instead." + ) + multilingual = False + + processed_audios = [] + for audio in audios: + if not isinstance(audio, np.ndarray): + audio = decode_audio(audio, sampling_rate=sampling_rate) + processed_audios.append(audio) + + features = [] + for audio in processed_audios: + feature = self.model.feature_extractor(audio)[..., :-1] + features.append(feature) + features = ( - [self.model.feature_extractor(chunk)[..., :-1] for chunk in audio_chunks] - if duration_after_vad - else [] + np.stack([pad_or_trim(feature) for feature in features]) if features else [] ) all_language_probs = None - # detecting the language if not provided + if language is None: if not self.model.model.is_multilingual: language = "en" @@ -484,7 +710,7 @@ def transcribe( np.full((self.model.model.n_mels, 1), -1.5, dtype="float32") ], axis=1, - ), # add a dummy feature to account for empty audio + ), language_detection_segments=language_detection_segments, language_detection_threshold=language_detection_threshold, ) @@ -511,10 +737,6 @@ def transcribe( language=language, ) - features = ( - np.stack([pad_or_trim(feature) for feature in features]) if features else [] - ) - options = TranscriptionOptions( beam_size=beam_size, best_of=best_of, @@ -555,13 +777,18 @@ def transcribe( info = TranscriptionInfo( language=language, language_probability=language_probability, - duration=duration, - duration_after_vad=duration_after_vad, + duration=processed_audios[0].shape[0] / sampling_rate, + duration_after_vad=processed_audios[0].shape[0] / sampling_rate, transcription_options=options, vad_options=vad_parameters, all_language_probs=all_language_probs, ) + chunks_metadata = [ + {"offset": 0, "duration": a.shape[0] / sampling_rate} + for a in processed_audios + ] + segments = self._batched_segments_generator( features, tokenizer, @@ -569,19 +796,27 @@ def transcribe( batch_size, options, log_progress, + text_only, ) - if not clip_timestamps_provided: - segments = restore_speech_timestamps( - segments, clip_timestamps, sampling_rate - ) return segments, info def _batched_segments_generator( - self, features, tokenizer, chunks_metadata, batch_size, options, log_progress + self, + features, + tokenizer, + chunks_metadata, + batch_size, + options, + log_progress, + text_only=False, ): + """ + Optimized version that can return text only results + """ pbar = tqdm(total=len(features), disable=not log_progress, position=0) seg_idx = 0 + for i in range(0, len(features), batch_size): results = self.forward( features[i : i + batch_size], 
@@ -593,29 +828,118 @@ def _batched_segments_generator( for result in results: for segment in result: seg_idx += 1 - yield Segment( - seek=segment["seek"], - id=seg_idx, - text=segment["text"], - start=round(segment["start"], 3), - end=round(segment["end"], 3), - words=( - None - if not options.word_timestamps - else [Word(**word) for word in segment["words"]] - ), - tokens=segment["tokens"], - avg_logprob=segment["avg_logprob"], - no_speech_prob=segment["no_speech_prob"], - compression_ratio=segment["compression_ratio"], - temperature=options.temperatures[0], - ) + if text_only: + # return text only + yield segment["text"] + else: + yield Segment( + seek=segment["seek"], + id=seg_idx, + text=segment["text"], + start=round(segment["start"], 3), + end=round(segment["end"], 3), + words=( + None + if not options.word_timestamps + else [Word(**word) for word in segment["words"]] + ), + tokens=segment["tokens"], + avg_logprob=segment["avg_logprob"], + no_speech_prob=segment["no_speech_prob"], + compression_ratio=segment["compression_ratio"], + temperature=options.temperatures[0], + ) pbar.update(1) pbar.close() self.last_speech_timestamp = 0.0 + def _batched_segments_generator_grouped( + self, + features, + tokenizer, + chunks_metadata, + audio_boundaries, + batch_size, + options, + log_progress, + ): + """ + Process batched features and return segments grouped by audio. + + Args: + features: Stacked features for all audios + tokenizer: Tokenizer instance + chunks_metadata: Metadata for all chunks + audio_boundaries: List of indices marking where each audio's chunks start/end + batch_size: Batch size for processing + options: Transcription options + log_progress: Whether to show progress bar + + Returns: + List of lists, where each inner list contains Segment objects for one audio + """ + if len(features) == 0: + # Return empty lists for each audio + return [[] for _ in range(len(audio_boundaries) - 1)] + + pbar = tqdm(total=len(features), disable=not log_progress, position=0) + + # Collect all results first + all_results = [] + for i in range(0, len(features), batch_size): + results = self.forward( + features[i : i + batch_size], + tokenizer, + chunks_metadata[i : i + batch_size], + options, + ) + all_results.extend(results) + pbar.update(len(results)) + + pbar.close() + self.last_speech_timestamp = 0.0 + + # Group results by audio using boundaries + num_audios = len(audio_boundaries) - 1 + grouped_segments = [] + + for audio_idx in range(num_audios): + start_idx = audio_boundaries[audio_idx] + end_idx = audio_boundaries[audio_idx + 1] + + audio_segments = [] + seg_idx = 0 + + for chunk_idx in range(start_idx, end_idx): + if chunk_idx < len(all_results): + for segment in all_results[chunk_idx]: + seg_idx += 1 + audio_segments.append( + Segment( + seek=segment["seek"], + id=seg_idx, + text=segment["text"], + start=round(segment["start"], 3), + end=round(segment["end"], 3), + words=( + None + if not options.word_timestamps + else [Word(**word) for word in segment["words"]] + ), + tokens=segment["tokens"], + avg_logprob=segment["avg_logprob"], + no_speech_prob=segment["no_speech_prob"], + compression_ratio=segment["compression_ratio"], + temperature=options.temperatures[0], + ) + ) + + grouped_segments.append(audio_segments) + + return grouped_segments + class WhisperModel: def __init__( diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 89017747..c2f6460b 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -271,6 +271,91 @@ def 
test_monotonic_timestamps(physcisworks_path): assert segments[-1].end <= info.duration +def test_transcribe_batch_multiple_audios(physcisworks_path): + model = WhisperModel("tiny") + batched_model = BatchedInferencePipeline(model=model) + result, info = batched_model.transcribe_batch_multiple_audios( + [physcisworks_path, physcisworks_path, physcisworks_path], batch_size=16 + ) + + assert info.language == "en" + assert info.language_probability > 0.7 + segments = [] + for segment in result: + segments.append( + {"start": segment.start, "end": segment.end, "text": segment.text} + ) + + assert len(segments) == 3 + + result, info = batched_model.transcribe_batch_multiple_audios( + [physcisworks_path, physcisworks_path, physcisworks_path], + batch_size=3, + without_timestamps=False, + word_timestamps=True, + ) + segments = [] + for segment in result: + assert segment.words is not None + segments.append( + {"start": segment.start, "end": segment.end, "text": segment.text} + ) + assert len(segments) > 3 + + +def test_transcribe_multiple_audios(jfk_path): + """Test transcribe() with a list of multiple audios.""" + model = WhisperModel("tiny") + batched_model = BatchedInferencePipeline(model=model) + + result, info = batched_model.transcribe( + [jfk_path, jfk_path, jfk_path], + batch_size=8, + ) + + assert info.language == "en" + assert info.language_probability > 0.7 + assert info.duration == 11 + + segments = list(result) + assert len(segments) == 3 + + for segment in segments: + assert "Americans" in segment.text or "country" in segment.text + + segments, info = batched_model.transcribe(jfk_path) + assert info.language == "en" + segments = list(segments) + assert len(segments) >= 1 + + +def test_transcribe_multiple_audios_with_word_timestamps(jfk_path): + """Test transcribe() with multiple audios and word timestamps.""" + model = WhisperModel("tiny") + batched_model = BatchedInferencePipeline(model=model) + + result, info = batched_model.transcribe( + [jfk_path, jfk_path], + batch_size=8, + word_timestamps=True, + without_timestamps=False, + ) + + assert info.language == "en" + + segments = list(result) + assert len(segments) >= 2 + + for segment in segments: + assert segment.words is not None + assert len(segment.words) > 0 + + for word in segment.words: + assert word.start is not None + assert word.end is not None + assert word.word is not None + + def test_cliptimestamps_segments(jfk_path): model = WhisperModel("tiny") pipeline = BatchedInferencePipeline(model=model)
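
A minimal usage sketch of the batched transcribe() API introduced in this diff, mirroring the new tests (test_transcribe_multiple_audios). The audio file paths and the "tiny" model size are placeholders; class names and the import path follow the existing faster_whisper test suite, and the behavior notes in the comments are taken from the new tests rather than guaranteed semantics.

    # Sketch only: paths below are placeholders, not files shipped with the repo.
    from faster_whisper import BatchedInferencePipeline, WhisperModel

    model = WhisperModel("tiny")
    batched_model = BatchedInferencePipeline(model=model)

    # Single audio: unchanged behavior, the result is a generator of Segment objects.
    segments, info = batched_model.transcribe("audio_a.wav", batch_size=8)
    for segment in segments:
        print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")

    # Multiple audios: pass a list; per the new tests, one language detection pass is
    # shared across inputs and iterating the result yields segments for all of them.
    segments, info = batched_model.transcribe(
        ["audio_a.wav", "audio_b.wav", "audio_c.wav"],
        batch_size=8,
        word_timestamps=True,
        without_timestamps=False,
    )
    print(info.language, info.language_probability)
    for segment in segments:
        print(segment.text)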