Skip to content

Commit 4017efc

Browse files
authored
Merge pull request #1356 from m-bain/revert-1355-batch_wrap
Revert "Batch wrap"
2 parents 422c44f + 0e073d4 commit 4017efc

File tree

4 files changed

+18
-268
lines changed

4 files changed

+18
-268
lines changed

whisperx/__main__.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,6 @@ def cli():
6060
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
6161
parser.add_argument("--hotwords", type=str, default=None, help="hotwords/hint phrases to the model (e.g. \"WhisperX, PyAnnote, GPU\"); improves recognition of rare/technical terms")
6262
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
63-
parser.add_argument("--batch_context", action="store_true", help="use previous batch's transcription as context for the next batch (slower but more coherent across batches)")
6463
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
6564

6665
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")

whisperx/asr.py

Lines changed: 18 additions & 121 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818

1919
logger = get_logger(__name__)
2020

21+
2122
def find_numeral_symbol_tokens(tokenizer):
2223
numeral_symbol_tokens = []
2324
for i in range(tokenizer.eot):
@@ -39,48 +40,28 @@ def generate_segment_batched(
3940
tokenizer: Tokenizer,
4041
options: TranscriptionOptions,
4142
encoder_output=None,
42-
use_batch_context: bool = False,
43-
previous_batch_context_tokens: List[List[int]] = None,
4443
):
4544
batch_size = features.shape[0]
46-
if previous_batch_context_tokens is None:
47-
previous_batch_context_tokens = [[] for _ in range(batch_size)]
48-
49-
initial_prompt_tokens = []
45+
all_tokens = []
46+
prompt_reset_since = 0
5047
if options.initial_prompt is not None:
5148
initial_prompt = " " + options.initial_prompt.strip()
5249
initial_prompt_tokens = tokenizer.encode(initial_prompt)
53-
54-
batch_tokens = []
55-
for i in range(batch_size):
56-
all_tokens = list(initial_prompt_tokens)
57-
if use_batch_context:
58-
if i < len(previous_batch_context_tokens):
59-
ctx = previous_batch_context_tokens[i]
60-
if ctx:
61-
# 223 is max prompt tokens
62-
available = 223 - len(all_tokens)
63-
if available > 0:
64-
all_tokens.extend(ctx[-available:])
65-
batch_tokens.append(all_tokens)
66-
67-
max_batch_tokens = max([len(t) for t in batch_tokens] + [0])
68-
69-
prompts = [
70-
self.get_prompt(
71-
tokenizer,
72-
[tokenizer.eot] * (max_batch_tokens - len(t)) + t,
73-
without_timestamps=options.without_timestamps,
74-
prefix=options.prefix,
75-
hotwords=options.hotwords
76-
) for t in batch_tokens
77-
]
50+
all_tokens.extend(initial_prompt_tokens)
51+
previous_tokens = all_tokens[prompt_reset_since:]
52+
prompt = self.get_prompt(
53+
tokenizer,
54+
previous_tokens,
55+
without_timestamps=options.without_timestamps,
56+
prefix=options.prefix,
57+
hotwords=options.hotwords
58+
)
7859

7960
encoder_output = self.encode(features)
8061

8162
result = self.model.generate(
8263
encoder_output,
83-
prompts,
64+
[prompt] * batch_size,
8465
beam_size=options.beam_size,
8566
patience=options.patience,
8667
length_penalty=options.length_penalty,
@@ -101,9 +82,9 @@ def decode_batch(tokens: List[List[int]]) -> List[str]:
10182
return tokenizer.tokenizer.decode_batch(res)
10283

10384
text = decode_batch(tokens_batch)
85+
10486
return text
10587

106-
10788
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
10889
# When the model is running on multiple GPUs, the encoder output should be moved
10990
# to the CPU since we don't know which GPU will handle the next job.
@@ -134,15 +115,13 @@ def __init__(
134115
framework="pt",
135116
language: Optional[str] = None,
136117
suppress_numerals: bool = False,
137-
use_batch_context: bool = False,
138118
**kwargs,
139119
):
140120
self.model = model
141121
self.tokenizer = tokenizer
142122
self.options = options
143123
self.preset_language = language
144124
self.suppress_numerals = suppress_numerals
145-
self.use_batch_context = use_batch_context
146125
self._batch_size = kwargs.pop("batch_size", None)
147126
self._num_workers = 1
148127
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
@@ -163,8 +142,6 @@ def __init__(
163142
super(Pipeline, self).__init__()
164143
self.vad_model = vad
165144
self._vad_params = vad_params
166-
self.previous_batch_context_tokens = []
167-
168145

169146
def _sanitize_parameters(self, **kwargs):
170147
preprocess_kwargs = {}
@@ -183,35 +160,7 @@ def preprocess(self, audio):
183160
return {'inputs': features}
184161

185162
def _forward(self, model_inputs):
186-
current_batch_size = model_inputs['inputs'].shape[0]
187-
# Ideally, batch[i] corresponds to stream[i].
188-
# This holds if batch_size == number of streams.
189-
valid_contexts = self.previous_batch_context_tokens[:current_batch_size]
190-
191-
outputs = self.model.generate_segment_batched(
192-
model_inputs['inputs'],
193-
self.tokenizer,
194-
self.options,
195-
use_batch_context=self.use_batch_context,
196-
previous_batch_context_tokens=valid_contexts,
197-
)
198-
if self.use_batch_context:
199-
initial_prompt_length = 0
200-
if self.options.initial_prompt is not None:
201-
initial_prompt = " " + self.options.initial_prompt.strip()
202-
initial_prompt_length = len(self.tokenizer.encode(initial_prompt))
203-
204-
# Use 220 instead of 224 to be safe
205-
max_context_window = max(0, 220 - initial_prompt_length)
206-
207-
for i, text in enumerate(outputs):
208-
if i < len(self.previous_batch_context_tokens):
209-
# Filter out special tokens (timestamps, SOT, EOT, etc.)
210-
# We only want the text content for context.
211-
tokens = [t for t in self.tokenizer.encode(text) if t < self.tokenizer.eot]
212-
self.previous_batch_context_tokens[i].extend(tokens)
213-
self.previous_batch_context_tokens[i] = self.previous_batch_context_tokens[i][-max_context_window:]
214-
163+
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
215164
return {'text': outputs}
216165

217166
def postprocess(self, model_outputs):
@@ -252,14 +201,6 @@ def transcribe(
252201
) -> TranscriptionResult:
253202
if isinstance(audio, str):
254203
audio = load_audio(audio)
255-
256-
batch_size = batch_size or self._batch_size
257-
# Initialize context for each stream.
258-
# We have 'batch_size' concurrent streams.
259-
if batch_size is None or batch_size < 1:
260-
batch_size = 1
261-
262-
self.previous_batch_context_tokens = [[] for _ in range(batch_size)]
263204

264205
def data(audio, segments):
265206
for seg in segments:
@@ -311,33 +252,10 @@ def data(audio, segments):
311252
new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
312253
new_suppressed_tokens = list(set(new_suppressed_tokens))
313254
self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
314-
255+
315256
segments: List[SingleSegment] = []
316257
batch_size = batch_size or self._batch_size
317258
total_segments = len(vad_segments)
318-
319-
if batch_size > 1 and self.use_batch_context:
320-
num_streams = batch_size
321-
# Distribute segments into streams
322-
# Manual split
323-
k, m = divmod(len(vad_segments), num_streams)
324-
# lengths of each part: first m parts have k+1, rest have k
325-
stream_segments = []
326-
start_idx = 0
327-
for i in range(num_streams):
328-
part_len = k + 1 if i < m else k
329-
stream_segments.append(vad_segments[start_idx : start_idx + part_len])
330-
start_idx += part_len
331-
# Interleave
332-
# We need to pick [s0[0], s1[0], s2[0]... s0[1], s1[1]...]
333-
interleaved_segments = []
334-
max_len = max(len(s) for s in stream_segments)
335-
for i in range(max_len):
336-
for stream in stream_segments:
337-
if i < len(stream):
338-
interleaved_segments.append(stream[i])
339-
vad_segments = interleaved_segments
340-
341259
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
342260
if print_progress:
343261
base_progress = ((idx + 1) / total_segments) * 100
@@ -356,25 +274,6 @@ def data(audio, segments):
356274
}
357275
)
358276

359-
if self.use_batch_context and batch_size > 1:
360-
last_stream_index = (total_segments - 1) % batch_size
361-
final_context = self.previous_batch_context_tokens[last_stream_index]
362-
# Prepare context for the wrap-around re-run
363-
# ONLY Stream 0 (which processes the start of the file) should get the context (which comes from the end of the file).
364-
# All other streams should have EMPTY context for this re-run to avoid self-referencing loops (feeding Segment N to Segment N).
365-
new_rerun_context = [[] for _ in range(batch_size)]
366-
new_rerun_context[0] = final_context
367-
# Temporarily overwrite previous_batch_context_tokens for the re-run
368-
self.previous_batch_context_tokens = new_rerun_context
369-
first_batch_segments = vad_segments[:batch_size]
370-
# Runs the model again just on 'first_batch_segments'
371-
for i, out in enumerate(self.__call__(data(audio, first_batch_segments), batch_size=batch_size, num_workers=num_workers)):
372-
text = out['text']
373-
# L398: Overwrite the existing text with the new wrap-around text
374-
segments[i]['text'] = text
375-
# Sort segments by start time to restore original order
376-
segments.sort(key=lambda x: x['start'])
377-
378277
# revert the tokenizer if multilingual inference is enabled
379278
if self.preset_language is None:
380279
self.tokenizer = None
@@ -390,8 +289,8 @@ def detect_language(self, audio: np.ndarray) -> str:
390289
logger.warning("Audio is shorter than 30s, language detection may be inaccurate")
391290
model_n_mels = self.model.feat_kwargs.get("feature_size")
392291
segment = log_mel_spectrogram(audio[: N_SAMPLES],
393-
n_mels=model_n_mels if model_n_mels is not None else 80,
394-
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
292+
n_mels=model_n_mels if model_n_mels is not None else 80,
293+
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
395294
encoder_output = self.model.encode(segment)
396295
results = self.model.model.detect_language(encoder_output)
397296
language_token, language_probability = results[0][0]
@@ -416,7 +315,6 @@ def load_model(
416315
local_files_only=False,
417316
threads=4,
418317
use_auth_token: Optional[Union[str, bool]] = None,
419-
use_batch_context: bool = False,
420318
) -> FasterWhisperPipeline:
421319
"""Load a Whisper model for inference.
422320
Args:
@@ -523,5 +421,4 @@ def load_model(
523421
language=language,
524422
suppress_numerals=suppress_numerals,
525423
vad_params=default_vad_options,
526-
use_batch_context=use_batch_context,
527424
)

whisperx/benchmark.py

Lines changed: 0 additions & 144 deletions
This file was deleted.

0 commit comments

Comments
 (0)