diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py index a73b31b5..8c5dfe33 100644 --- a/faster_whisper/transcribe.py +++ b/faster_whisper/transcribe.py @@ -106,6 +106,7 @@ class TranscriptionInfo: all_language_probs: Optional[List[Tuple[str, float]]] transcription_options: TranscriptionOptions vad_options: VadOptions + num_chunks: Optional[List[int]] = None class BatchedInferencePipeline: @@ -251,6 +252,47 @@ def generate_segment_batched( return encoder_output, output + def batch_audio_files( + self, + audio_files: Union[str, list[str]], + sampling_rate: int = 16000, + batch_size: int = 8, + ) -> Iterable[Tuple[np.ndarray, list[int]]]: + """Batch multiple audio files together similar to PyTorch's collate_fn. + + Arguments: + audio_files: Path to a single audio file (or a file-like object), or a list of files. + sampling_rate: Resample the audio to this sample rate. + batch_size: Max size of a single batch. + + Returns: + A generator of numpy arrays with size (batch_size, X) where + X is the max audio size in the batch. + + A generator of non-padded valid lengths of each audio file in the batch + """ + + if not isinstance(audio_files, list): + audio_files = [audio_files] + + for i in range(0, len(audio_files), batch_size): + batch_audio_files = audio_files[i : i + batch_size] + + audios = [] + lens = [] + max_len_in_batch = 0 + for audio_file in batch_audio_files: + audio = decode_audio(audio_file, sampling_rate=sampling_rate) + max_len_in_batch = max(max_len_in_batch, len(audio)) + lens.append(len(audio)) + audios.append(audio) + + batched_audios = np.stack( + [pad_or_trim(audio, length=max_len_in_batch) for audio in audios] + ) + + yield (batched_audios, lens) + def transcribe( self, audio: Union[str, BinaryIO, np.ndarray], @@ -341,7 +383,7 @@ def transcribe( clip_timestamps: Optionally provide list of dictionaries each containing "start" and "end" keys that specify the start and end of the voiced region within `chunk_length` boundary. vad_filter will be ignored if clip_timestamps is used. - batch_size: the maximum number of parallel requests to model for decoding. + batch_size: the maximum number of audio files to process in parallel. hotwords: Hotwords/hint phrases to the model. Has no effect if prefix is not None. language_detection_threshold: If the maximum probability of the language tokens is @@ -383,16 +425,7 @@ def transcribe( ) multilingual = False - if not isinstance(audio, np.ndarray): - audio = decode_audio(audio, sampling_rate=sampling_rate) - duration = audio.shape[0] / sampling_rate - - self.model.logger.info( - "Processing audio with duration %s", format_timestamp(duration) - ) - chunk_length = chunk_length or self.model.feature_extractor.chunk_length - # if no segment split is provided, use vad_model and generate segments if not clip_timestamps: if vad_filter: if vad_parameters is None: @@ -407,92 +440,16 @@ def transcribe( vad_parameters = VadOptions( **vad_parameters, max_speech_duration_s=chunk_length ) - - clip_timestamps = get_speech_timestamps(audio, vad_parameters) - # run the audio if it is less than 30 sec even without clip_timestamps - elif duration < chunk_length: - clip_timestamps = [{"start": 0, "end": audio.shape[0]}] - else: - raise RuntimeError( - "No clip timestamps found. " - "Set 'vad_filter' to True or provide 'clip_timestamps'." 
- ) - - audio_chunks, chunks_metadata = collect_chunks( - audio, clip_timestamps, max_duration=chunk_length - ) - else: + raise NotImplementedError("if clip_timestamps is provided") clip_timestamps = [ {k: int(v * sampling_rate) for k, v in segment.items()} for segment in clip_timestamps ] - audio_chunks, chunks_metadata = [], [] - for clip in clip_timestamps: - audio_chunks.append(audio[clip["start"] : clip["end"]]) - chunks_metadata.append( - { - "offset": clip["start"] / sampling_rate, - "duration": (clip["end"] - clip["start"]) / sampling_rate, - "segments": [clip], - } - ) - - duration_after_vad = ( - sum((segment["end"] - segment["start"]) for segment in clip_timestamps) - / sampling_rate - ) - - self.model.logger.info( - "VAD filter removed %s of audio", - format_timestamp(duration - duration_after_vad), - ) - - features = ( - [self.model.feature_extractor(chunk)[..., :-1] for chunk in audio_chunks] - if duration_after_vad - else [] - ) - - all_language_probs = None - # detecting the language if not provided - if language is None: - if not self.model.model.is_multilingual: - language = "en" - language_probability = 1 - else: - ( - language, - language_probability, - all_language_probs, - ) = self.model.detect_language( - features=np.concatenate( - features - + [ - np.full((self.model.model.n_mels, 1), -1.5, dtype="float32") - ], - axis=1, - ), # add a dummy feature to account for empty audio - language_detection_segments=language_detection_segments, - language_detection_threshold=language_detection_threshold, - ) - - self.model.logger.info( - "Detected language '%s' with probability %.2f", - language, - language_probability, - ) - else: - if not self.model.model.is_multilingual and language != "en": - self.model.logger.warning( - "The current model is English-only but the language parameter is set to '%s'; " - "using 'en' instead." % language - ) - language = "en" - - language_probability = 1 - + # if this was inside the batch loop, we would allow for + # batched files to be different languages. 
for now + # let's assume all files are in the same language tokenizer = Tokenizer( self.model.hf_tokenizer, self.model.model.is_multilingual, @@ -500,10 +457,6 @@ def transcribe( language=language, ) - features = ( - np.stack([pad_or_trim(feature) for feature in features]) if features else [] - ) - options = TranscriptionOptions( beam_size=beam_size, best_of=best_of, @@ -534,34 +487,215 @@ def transcribe( word_timestamps=word_timestamps, hallucination_silence_threshold=None, condition_on_previous_text=False, - clip_timestamps=clip_timestamps, + clip_timestamps=None, # this field is not used in the BatchedInferencePipeline prompt_reset_on_temperature=0.5, multilingual=multilingual, without_timestamps=without_timestamps, max_initial_timestamp=0.0, ) - info = TranscriptionInfo( - language=language, - language_probability=language_probability, - duration=duration, - duration_after_vad=duration_after_vad, - transcription_options=options, - vad_options=vad_parameters, - all_language_probs=all_language_probs, - ) + if not isinstance(audio, np.ndarray): + batch_generator = self.batch_audio_files( + audio, batch_size=batch_size, sampling_rate=sampling_rate + ) + else: + if audio.ndim != 2: + raise ValueError( + "Input audio must have a single batch dimension if provided as numpy array" + ) + batch_generator = audio - segments = self._batched_segments_generator( - features, - tokenizer, - chunks_metadata, - batch_size, - options, - log_progress, - ) - segments = restore_speech_timestamps(segments, clip_timestamps, sampling_rate) + batched_segments = [] + batched_info = [] + for audio, valid_lens in batch_generator: + curr_batch_size = audio.shape[0] + batch_duration = audio.shape[-1] / sampling_rate + batched_durations = [valid_len / sampling_rate for valid_len in valid_lens] - return segments, info + for duration in batched_durations: + self.model.logger.info( + "Processing audio with duration %s", format_timestamp(duration) + ) + + # if no segment split is provided, use vad_model and generate segments + if not clip_timestamps: + if vad_filter: + batched_clip_timestamps = get_speech_timestamps( + audio, vad_parameters + ) + # run the audio if it is less than 30 sec even without clip_timestamps + elif batch_duration < chunk_length: + batched_clip_timestamps = [ + {"start": 0, "end": valid_len} for valid_len in valid_lens + ] + else: + raise RuntimeError( + "No clip timestamps found. " + "Set 'vad_filter' to True or provide 'clip_timestamps'." 
+ ) + + batched_audio_chunks = [] + batched_chunks_metadata = [] + + for i, clip in enumerate(audio): + audio_chunks, chunks_metadata = collect_chunks( + clip, batched_clip_timestamps[i], max_duration=chunk_length + ) + batched_audio_chunks.append(audio_chunks) + batched_chunks_metadata.append(chunks_metadata) + + else: + raise NotImplementedError( + "if clip_timestamps is provided in batch loop" + ) + audio_chunks, chunks_metadata = [], [] + for clip in clip_timestamps: + audio_chunks.append(audio[clip["start"] : clip["end"]]) + chunks_metadata.append( + { + "offset": clip["start"] / sampling_rate, + "duration": (clip["end"] - clip["start"]) / sampling_rate, + "segments": [clip], + } + ) + + batched_features = [] + batched_durations_after_vad = [] + + for i in range(curr_batch_size): + batched_durations_after_vad.append( + sum( + (segment["end"] - segment["start"]) + for segment in batched_clip_timestamps[i] + ) + / sampling_rate + ) + + self.model.logger.info( + "VAD filter removed %s of audio", + format_timestamp( + batched_durations[i] - batched_durations_after_vad[i] + ), + ) + + batched_features.append( + [ + self.model.feature_extractor(chunk)[..., :-1] + for chunk in batched_audio_chunks[i] + ] + if batched_durations_after_vad[i] + else [] + ) + + all_language_probs = None + # detecting the language if not provided + if language is None: + raise NotImplementedError("when language is not provided") + if not self.model.model.is_multilingual: + language = "en" + language_probability = 1 + else: + ( + language, + language_probability, + all_language_probs, + ) = self.model.detect_language( + features=np.concatenate( + batched_features + + [ + np.full( + (self.model.model.n_mels, 1), -1.5, dtype="float32" + ) + ], + axis=1, + ), # add a dummy feature to account for empty audio + language_detection_segments=language_detection_segments, + language_detection_threshold=language_detection_threshold, + ) + + self.model.logger.info( + "Detected language '%s' with probability %.2f", + language, + language_probability, + ) + else: + print("TODO: can move this outside of batch loop") + if not self.model.model.is_multilingual and language != "en": + self.model.logger.warning( + "The current model is English-only but the language parameter is set to '%s'; " + "using 'en' instead." % language + ) + language = "en" + + language_probability = 1 + + # flattening to have a single batch dimension (align with numpy C-order) + # we keep a single batch dimensions using vertical concat, despite varying number of chunks + # per audio file. 
to keep track, we refer to batched_chunks_metadata to see + # which chunks correspond to which audio file (number of chunks per audio file) + flat_padded_batched_features = [ + ( + np.stack([pad_or_trim(feature) for feature in features]) + if features + else [] + ) + for features in batched_features + ] + flat_padded_batched_features = np.concatenate(flat_padded_batched_features) + + flat_chunks_metadata = [ + chunk + for file_chunks in batched_chunks_metadata + for chunk in file_chunks + ] + + flat_batched_clip_timestamps = [ + clip + for clip_timestamps in batched_clip_timestamps + for clip in clip_timestamps + ] + + # we regroup within outside of this function (user does it) + # this is because segments is an iterator and thus we can't + # regroup based on number of chunks in metadata + flat_segments = self._batched_segments_generator( + flat_padded_batched_features, + tokenizer, + flat_chunks_metadata, + batch_size, + options, + log_progress, + ) + + flat_segments = restore_speech_timestamps( + flat_segments, flat_batched_clip_timestamps, sampling_rate + ) + + # regrouping after flattening + info = [] + for i in range(curr_batch_size): + # this option is not used by _batched_segments_generator + options.clip_timestamps = batched_clip_timestamps[i] + + info.append( + TranscriptionInfo( + language=language, + language_probability=language_probability, + duration=batched_durations[i], + duration_after_vad=batched_durations_after_vad[i], + transcription_options=options, + vad_options=vad_parameters, + all_language_probs=all_language_probs, + # so that user knows how many segments to consume for each audio file + # when iterating over flat_segments + num_chunks=len(batched_chunks_metadata[i]), + ) + ) + + batched_segments.append(flat_segments) + batched_info.append(info) + + return batched_segments, batched_info def _batched_segments_generator( self, features, tokenizer, chunks_metadata, batch_size, options, log_progress diff --git a/faster_whisper/vad.py b/faster_whisper/vad.py index cc42f371..7ee7d644 100644 --- a/faster_whisper/vad.py +++ b/faster_whisper/vad.py @@ -79,108 +79,116 @@ def get_speech_timestamps( min_silence_samples = sampling_rate * min_silence_duration_ms / 1000 min_silence_samples_at_max_speech = sampling_rate * 98 / 1000 - audio_length_samples = len(audio) + audio_length_samples = audio.shape[-1] model = get_vad_model() padded_audio = np.pad( - audio, (0, window_size_samples - audio.shape[0] % window_size_samples) + audio, + ((0, 0), (0, window_size_samples - audio.shape[-1] % window_size_samples)), ) - speech_probs = model(padded_audio.reshape(1, -1)).squeeze(0) - - triggered = False - speeches = [] - current_speech = {} - if neg_threshold is None: - neg_threshold = max(threshold - 0.15, 0.01) - - # to save potential segment end (and tolerate some silence) - temp_end = 0 - # to save potential segment limits in case of maximum segment size reached - prev_end = next_start = 0 - - for i, speech_prob in enumerate(speech_probs): - if (speech_prob >= threshold) and temp_end: - temp_end = 0 - if next_start < prev_end: - next_start = window_size_samples * i - - if (speech_prob >= threshold) and not triggered: - triggered = True - current_speech["start"] = window_size_samples * i - continue - - if ( - triggered - and (window_size_samples * i) - current_speech["start"] > max_speech_samples - ): - if prev_end: - current_speech["end"] = prev_end - speeches.append(current_speech) - current_speech = {} - # previously reached silence (< neg_thres) and is still not speech (< thres) + 
batched_speech_probs = model(padded_audio).squeeze(-1) + + batched_speeches = [] + for speech_probs in batched_speech_probs: + triggered = False + speeches = [] + current_speech = {} + if neg_threshold is None: + neg_threshold = max(threshold - 0.15, 0.01) + + # to save potential segment end (and tolerate some silence) + temp_end = 0 + # to save potential segment limits in case of maximum segment size reached + prev_end = next_start = 0 + + for i, speech_prob in enumerate(speech_probs): + if (speech_prob >= threshold) and temp_end: + temp_end = 0 if next_start < prev_end: - triggered = False - else: - current_speech["start"] = next_start - prev_end = next_start = temp_end = 0 - else: - current_speech["end"] = window_size_samples * i - speeches.append(current_speech) - current_speech = {} - prev_end = next_start = temp_end = 0 - triggered = False - continue + next_start = window_size_samples * i - if (speech_prob < neg_threshold) and triggered: - if not temp_end: - temp_end = window_size_samples * i - # condition to avoid cutting in very short silence - if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech: - prev_end = temp_end - if (window_size_samples * i) - temp_end < min_silence_samples: + if (speech_prob >= threshold) and not triggered: + triggered = True + current_speech["start"] = window_size_samples * i continue - else: - current_speech["end"] = temp_end - if ( - current_speech["end"] - current_speech["start"] - ) > min_speech_samples: + + if ( + triggered + and (window_size_samples * i) - current_speech["start"] + > max_speech_samples + ): + if prev_end: + current_speech["end"] = prev_end speeches.append(current_speech) - current_speech = {} - prev_end = next_start = temp_end = 0 - triggered = False - continue + current_speech = {} + # previously reached silence (< neg_thres) and is still not speech (< thres) + if next_start < prev_end: + triggered = False + else: + current_speech["start"] = next_start + prev_end = next_start = temp_end = 0 + else: + current_speech["end"] = window_size_samples * i + speeches.append(current_speech) + current_speech = {} + prev_end = next_start = temp_end = 0 + triggered = False + continue - if ( - current_speech - and (audio_length_samples - current_speech["start"]) > min_speech_samples - ): - current_speech["end"] = audio_length_samples - speeches.append(current_speech) - - for i, speech in enumerate(speeches): - if i == 0: - speech["start"] = int(max(0, speech["start"] - speech_pad_samples)) - if i != len(speeches) - 1: - silence_duration = speeches[i + 1]["start"] - speech["end"] - if silence_duration < 2 * speech_pad_samples: - speech["end"] += int(silence_duration // 2) - speeches[i + 1]["start"] = int( - max(0, speeches[i + 1]["start"] - silence_duration // 2) - ) + if (speech_prob < neg_threshold) and triggered: + if not temp_end: + temp_end = window_size_samples * i + # condition to avoid cutting in very short silence + if ( + window_size_samples * i + ) - temp_end > min_silence_samples_at_max_speech: + prev_end = temp_end + if (window_size_samples * i) - temp_end < min_silence_samples: + continue + else: + current_speech["end"] = temp_end + if ( + current_speech["end"] - current_speech["start"] + ) > min_speech_samples: + speeches.append(current_speech) + current_speech = {} + prev_end = next_start = temp_end = 0 + triggered = False + continue + + if ( + current_speech + and (audio_length_samples - current_speech["start"]) > min_speech_samples + ): + current_speech["end"] = audio_length_samples + 
speeches.append(current_speech) + + for i, speech in enumerate(speeches): + if i == 0: + speech["start"] = int(max(0, speech["start"] - speech_pad_samples)) + if i != len(speeches) - 1: + silence_duration = speeches[i + 1]["start"] - speech["end"] + if silence_duration < 2 * speech_pad_samples: + speech["end"] += int(silence_duration // 2) + speeches[i + 1]["start"] = int( + max(0, speeches[i + 1]["start"] - silence_duration // 2) + ) + else: + speech["end"] = int( + min(audio_length_samples, speech["end"] + speech_pad_samples) + ) + speeches[i + 1]["start"] = int( + max(0, speeches[i + 1]["start"] - speech_pad_samples) + ) else: speech["end"] = int( min(audio_length_samples, speech["end"] + speech_pad_samples) ) - speeches[i + 1]["start"] = int( - max(0, speeches[i + 1]["start"] - speech_pad_samples) - ) - else: - speech["end"] = int( - min(audio_length_samples, speech["end"] + speech_pad_samples) - ) - return speeches + batched_speeches.append(speeches) + + return batched_speeches def collect_chunks( @@ -326,7 +334,7 @@ def __call__( audio.ndim == 2 ), "Input should be a 2D array with size (batch_size, num_samples)" assert ( - audio.shape[1] % num_samples == 0 + audio.shape[-1] % num_samples == 0 ), "Input size should be a multiple of num_samples" batch_size = audio.shape[0] diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py index 48b409eb..73eedbf7 100644 --- a/tests/test_transcribe.py +++ b/tests/test_transcribe.py @@ -59,14 +59,126 @@ def test_transcribe(jfk_path): ) -def test_batched_transcribe(physcisworks_path): +def test_batch_audio_files(data_dir): + model = WhisperModel("tiny") + batched_model = BatchedInferencePipeline(model=model) + + hotwords = os.path.join(data_dir, "hotwords.mp3") + hotwords_audio = decode_audio(hotwords) + + jfk = os.path.join(data_dir, "jfk.flac") + jfk_audio = decode_audio(jfk) + + audio_files = [ + hotwords, # batch 0 + hotwords, + jfk, # batch 1 + jfk, + jfk, # batch 2 + hotwords, + hotwords, # batch 3 + ] + + batch_size = 2 + batch_generator = batched_model.batch_audio_files( + audio_files, batch_size=batch_size + ) + + for i, (audios, valid_lens) in enumerate(batch_generator): + if i < 3: + # batches 0-2 are filled batches + assert audios.shape[0] == 2 + + if i == 0: + assert audios.shape[-1] == hotwords_audio.shape[-1] + + # in batch 2, we should pad to jfk size since len(jfk) > len(hotwords) + elif i == 1 or i == 2: + assert audios.shape[-1] == jfk_audio.shape[-1] + + # make sure we can recover original lengths + if i == 2: + assert valid_lens[0] == jfk_audio.shape[-1] + assert valid_lens[1] == hotwords_audio.shape[-1] + + else: + # only batch 3 is an unfilled batch + assert audios.shape[0] == 1 + assert audios.shape[-1] == hotwords_audio.shape[-1] + + +def test_batched_transcribe_many(jfk_path, physcisworks_path): + model = WhisperModel("tiny") + batched_model = BatchedInferencePipeline(model=model) + + physcisworks_audio = decode_audio(physcisworks_path) + jfk_audio = decode_audio(jfk_path) + + audio_files = [ + physcisworks_path, # batch 0 + jfk_path, + jfk_path, # batch 1 + jfk_path, + physcisworks_path, # batch 2 + physcisworks_path, + physcisworks_path, # batch 3 + ] + + batch_size = 2 + num_batches = (len(audio_files) // batch_size) + 1 + + batch_segments, batch_info = batched_model.transcribe( + audio_files, batch_size=batch_size, language="en" + ) + + # transcribe returns a list of generators with size equal to number of batches + # iterate through each batch and then through each generator in that batch + # to get a flat 
list of segments (processed in parallel) + # then recreate hierarchy by stacking chunks for each audio file + regrouped_segments = [] + total_flat_segments = [] + for batch_idx in range(num_batches): + info = batch_info[batch_idx] + + flat_segments = [] + for segment in batch_segments[batch_idx]: + flat_segments.append(segment) + total_flat_segments.append(segment) + + chunk_idx = 0 + for audio_info in info: + num_chunks = audio_info.num_chunks + regrouped_segments.append(flat_segments[chunk_idx : chunk_idx + num_chunks]) + chunk_idx += num_chunks + + num_jfk_files = 3 + num_physics_files = 4 + expected_num_chunks_jfk = 1 * num_jfk_files + expected_num_chunks_physics = 6 * num_physics_files + expected_total_chunks = expected_num_chunks_jfk + expected_num_chunks_physics + + assert len(total_flat_segments) == expected_total_chunks + assert len(regrouped_segments) == len(audio_files) + + for i in range(1, 4): + # because jfk only has one segment + assert regrouped_segments[i][0].text == ( + " And so my fellow Americans ask not what your country can do for you, " + "ask what you can do for your country." + ) + + # TODO: assert result for each other flat segments are identical to non-batched result + + +def test_batched_transcribe_one(physcisworks_path): model = WhisperModel("tiny") batched_model = BatchedInferencePipeline(model=model) result, info = batched_model.transcribe(physcisworks_path, batch_size=16) - assert info.language == "en" - assert info.language_probability > 0.7 + assert info[0][0].language == "en" + assert info[0][0].language_probability > 0.7 segments = [] - for segment in result: + + for segment in result[0]: segments.append( {"start": segment.start, "end": segment.end, "text": segment.text} ) @@ -80,7 +192,7 @@ def test_batched_transcribe(physcisworks_path): word_timestamps=True, ) segments = [] - for segment in result: + for segment in result[0]: assert segment.words is not None segments.append( {"start": segment.start, "end": segment.end, "text": segment.text}
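
# --- Usage sketch (reviewer note, not part of the patch) ---
# A minimal example of how a caller might consume the new batched API and
# regroup the flattened segment stream per input file via
# TranscriptionInfo.num_chunks, mirroring test_batched_transcribe_many above.
# Assumptions: the "tiny" model size and the audio paths are placeholders, and,
# as in the test, each chunk is taken to yield exactly one segment.

from faster_whisper import BatchedInferencePipeline, WhisperModel

model = WhisperModel("tiny")
batched_model = BatchedInferencePipeline(model=model)

audio_files = ["audio_a.wav", "audio_b.wav", "audio_c.wav"]  # placeholder paths
batch_size = 2

# transcribe() now returns one segment generator and one list of
# TranscriptionInfo per batch.
batched_segments, batched_info = batched_model.transcribe(
    audio_files, batch_size=batch_size, language="en"
)

regrouped = []  # one list of segments per input file, in input order
for segments_gen, infos in zip(batched_segments, batched_info):
    flat_segments = list(segments_gen)  # chunks for every file in this batch
    chunk_idx = 0
    for info in infos:
        # num_chunks says how many consecutive chunks belong to this file
        regrouped.append(flat_segments[chunk_idx : chunk_idx + info.num_chunks])
        chunk_idx += info.num_chunks

for path, file_segments in zip(audio_files, regrouped):
    print(path, " ".join(segment.text for segment in file_segments))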