Skip to content

Commit 422c44f

Browse files
authored
Merge pull request #1355 from 1carlito/batch_wrap
Batch wrap
2 parents 42beab1 + 1b6a3b7 commit 422c44f

File tree

4 files changed

+268
-18
lines changed

4 files changed

+268
-18
lines changed

whisperx/__main__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def cli():
6060
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
6161
parser.add_argument("--hotwords", type=str, default=None, help="hotwords/hint phrases to the model (e.g. \"WhisperX, PyAnnote, GPU\"); improves recognition of rare/technical terms")
6262
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
63+
parser.add_argument("--batch_context", action="store_true", help="use previous batch's transcription as context for the next batch (slower but more coherent across batches)")
6364
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
6465

6566
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")

whisperx/asr.py

Lines changed: 121 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
logger = get_logger(__name__)
2020

21-
2221
def find_numeral_symbol_tokens(tokenizer):
2322
numeral_symbol_tokens = []
2423
for i in range(tokenizer.eot):
@@ -40,28 +39,48 @@ def generate_segment_batched(
4039
tokenizer: Tokenizer,
4140
options: TranscriptionOptions,
4241
encoder_output=None,
42+
use_batch_context: bool = False,
43+
previous_batch_context_tokens: List[List[int]] = None,
4344
):
4445
batch_size = features.shape[0]
45-
all_tokens = []
46-
prompt_reset_since = 0
46+
if previous_batch_context_tokens is None:
47+
previous_batch_context_tokens = [[] for _ in range(batch_size)]
48+
49+
initial_prompt_tokens = []
4750
if options.initial_prompt is not None:
4851
initial_prompt = " " + options.initial_prompt.strip()
4952
initial_prompt_tokens = tokenizer.encode(initial_prompt)
50-
all_tokens.extend(initial_prompt_tokens)
51-
previous_tokens = all_tokens[prompt_reset_since:]
52-
prompt = self.get_prompt(
53-
tokenizer,
54-
previous_tokens,
55-
without_timestamps=options.without_timestamps,
56-
prefix=options.prefix,
57-
hotwords=options.hotwords
58-
)
53+
54+
batch_tokens = []
55+
for i in range(batch_size):
56+
all_tokens = list(initial_prompt_tokens)
57+
if use_batch_context:
58+
if i < len(previous_batch_context_tokens):
59+
ctx = previous_batch_context_tokens[i]
60+
if ctx:
61+
# 223 is max prompt tokens
62+
available = 223 - len(all_tokens)
63+
if available > 0:
64+
all_tokens.extend(ctx[-available:])
65+
batch_tokens.append(all_tokens)
66+
67+
max_batch_tokens = max([len(t) for t in batch_tokens] + [0])
68+
69+
prompts = [
70+
self.get_prompt(
71+
tokenizer,
72+
[tokenizer.eot] * (max_batch_tokens - len(t)) + t,
73+
without_timestamps=options.without_timestamps,
74+
prefix=options.prefix,
75+
hotwords=options.hotwords
76+
) for t in batch_tokens
77+
]
5978

6079
encoder_output = self.encode(features)
6180

6281
result = self.model.generate(
6382
encoder_output,
64-
[prompt] * batch_size,
83+
prompts,
6584
beam_size=options.beam_size,
6685
patience=options.patience,
6786
length_penalty=options.length_penalty,
@@ -82,9 +101,9 @@ def decode_batch(tokens: List[List[int]]) -> List[str]:
82101
return tokenizer.tokenizer.decode_batch(res)
83102

84103
text = decode_batch(tokens_batch)
85-
86104
return text
87105

106+
88107
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
89108
# When the model is running on multiple GPUs, the encoder output should be moved
90109
# to the CPU since we don't know which GPU will handle the next job.
@@ -115,13 +134,15 @@ def __init__(
115134
framework="pt",
116135
language: Optional[str] = None,
117136
suppress_numerals: bool = False,
137+
use_batch_context: bool = False,
118138
**kwargs,
119139
):
120140
self.model = model
121141
self.tokenizer = tokenizer
122142
self.options = options
123143
self.preset_language = language
124144
self.suppress_numerals = suppress_numerals
145+
self.use_batch_context = use_batch_context
125146
self._batch_size = kwargs.pop("batch_size", None)
126147
self._num_workers = 1
127148
self._preprocess_params, self._forward_params, self._postprocess_params = self._sanitize_parameters(**kwargs)
@@ -142,6 +163,8 @@ def __init__(
142163
super(Pipeline, self).__init__()
143164
self.vad_model = vad
144165
self._vad_params = vad_params
166+
self.previous_batch_context_tokens = []
167+
145168

146169
def _sanitize_parameters(self, **kwargs):
147170
preprocess_kwargs = {}
@@ -160,7 +183,35 @@ def preprocess(self, audio):
160183
return {'inputs': features}
161184

162185
def _forward(self, model_inputs):
163-
outputs = self.model.generate_segment_batched(model_inputs['inputs'], self.tokenizer, self.options)
186+
current_batch_size = model_inputs['inputs'].shape[0]
187+
# Ideally, batch[i] corresponds to stream[i].
188+
# This holds if batch_size == number of streams.
189+
valid_contexts = self.previous_batch_context_tokens[:current_batch_size]
190+
191+
outputs = self.model.generate_segment_batched(
192+
model_inputs['inputs'],
193+
self.tokenizer,
194+
self.options,
195+
use_batch_context=self.use_batch_context,
196+
previous_batch_context_tokens=valid_contexts,
197+
)
198+
if self.use_batch_context:
199+
initial_prompt_length = 0
200+
if self.options.initial_prompt is not None:
201+
initial_prompt = " " + self.options.initial_prompt.strip()
202+
initial_prompt_length = len(self.tokenizer.encode(initial_prompt))
203+
204+
# Use 220 instead of 224 to be safe
205+
max_context_window = max(0, 220 - initial_prompt_length)
206+
207+
for i, text in enumerate(outputs):
208+
if i < len(self.previous_batch_context_tokens):
209+
# Filter out special tokens (timestamps, SOT, EOT, etc.)
210+
# We only want the text content for context.
211+
tokens = [t for t in self.tokenizer.encode(text) if t < self.tokenizer.eot]
212+
self.previous_batch_context_tokens[i].extend(tokens)
213+
self.previous_batch_context_tokens[i] = self.previous_batch_context_tokens[i][-max_context_window:]
214+
164215
return {'text': outputs}
165216

166217
def postprocess(self, model_outputs):
@@ -201,6 +252,14 @@ def transcribe(
201252
) -> TranscriptionResult:
202253
if isinstance(audio, str):
203254
audio = load_audio(audio)
255+
256+
batch_size = batch_size or self._batch_size
257+
# Initialize context for each stream.
258+
# We have 'batch_size' concurrent streams.
259+
if batch_size is None or batch_size < 1:
260+
batch_size = 1
261+
262+
self.previous_batch_context_tokens = [[] for _ in range(batch_size)]
204263

205264
def data(audio, segments):
206265
for seg in segments:
@@ -252,10 +311,33 @@ def data(audio, segments):
252311
new_suppressed_tokens = numeral_symbol_tokens + self.options.suppress_tokens
253312
new_suppressed_tokens = list(set(new_suppressed_tokens))
254313
self.options = replace(self.options, suppress_tokens=new_suppressed_tokens)
255-
314+
256315
segments: List[SingleSegment] = []
257316
batch_size = batch_size or self._batch_size
258317
total_segments = len(vad_segments)
318+
319+
if batch_size > 1 and self.use_batch_context:
320+
num_streams = batch_size
321+
# Distribute segments into streams
322+
# Manual split
323+
k, m = divmod(len(vad_segments), num_streams)
324+
# lengths of each part: first m parts have k+1, rest have k
325+
stream_segments = []
326+
start_idx = 0
327+
for i in range(num_streams):
328+
part_len = k + 1 if i < m else k
329+
stream_segments.append(vad_segments[start_idx : start_idx + part_len])
330+
start_idx += part_len
331+
# Interleave
332+
# We need to pick [s0[0], s1[0], s2[0]... s0[1], s1[1]...]
333+
interleaved_segments = []
334+
max_len = max(len(s) for s in stream_segments)
335+
for i in range(max_len):
336+
for stream in stream_segments:
337+
if i < len(stream):
338+
interleaved_segments.append(stream[i])
339+
vad_segments = interleaved_segments
340+
259341
for idx, out in enumerate(self.__call__(data(audio, vad_segments), batch_size=batch_size, num_workers=num_workers)):
260342
if print_progress:
261343
base_progress = ((idx + 1) / total_segments) * 100
@@ -274,6 +356,25 @@ def data(audio, segments):
274356
}
275357
)
276358

359+
if self.use_batch_context and batch_size > 1:
360+
last_stream_index = (total_segments - 1) % batch_size
361+
final_context = self.previous_batch_context_tokens[last_stream_index]
362+
# Prepare context for the wrap-around re-run
363+
# ONLY Stream 0 (which processes the start of the file) should get the context (which comes from the end of the file).
364+
# All other streams should have EMPTY context for this re-run to avoid self-referencing loops (feeding Segment N to Segment N).
365+
new_rerun_context = [[] for _ in range(batch_size)]
366+
new_rerun_context[0] = final_context
367+
# Temporarily overwrite previous_batch_context_tokens for the re-run
368+
self.previous_batch_context_tokens = new_rerun_context
369+
first_batch_segments = vad_segments[:batch_size]
370+
# Runs the model again just on 'first_batch_segments'
371+
for i, out in enumerate(self.__call__(data(audio, first_batch_segments), batch_size=batch_size, num_workers=num_workers)):
372+
text = out['text']
373+
# L398: Overwrite the existing text with the new wrap-around text
374+
segments[i]['text'] = text
375+
# Sort segments by start time to restore original order
376+
segments.sort(key=lambda x: x['start'])
377+
277378
# revert the tokenizer if multilingual inference is enabled
278379
if self.preset_language is None:
279380
self.tokenizer = None
@@ -289,8 +390,8 @@ def detect_language(self, audio: np.ndarray) -> str:
289390
logger.warning("Audio is shorter than 30s, language detection may be inaccurate")
290391
model_n_mels = self.model.feat_kwargs.get("feature_size")
291392
segment = log_mel_spectrogram(audio[: N_SAMPLES],
292-
n_mels=model_n_mels if model_n_mels is not None else 80,
293-
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
393+
n_mels=model_n_mels if model_n_mels is not None else 80,
394+
padding=0 if audio.shape[0] >= N_SAMPLES else N_SAMPLES - audio.shape[0])
294395
encoder_output = self.model.encode(segment)
295396
results = self.model.model.detect_language(encoder_output)
296397
language_token, language_probability = results[0][0]
@@ -315,6 +416,7 @@ def load_model(
315416
local_files_only=False,
316417
threads=4,
317418
use_auth_token: Optional[Union[str, bool]] = None,
419+
use_batch_context: bool = False,
318420
) -> FasterWhisperPipeline:
319421
"""Load a Whisper model for inference.
320422
Args:
@@ -421,4 +523,5 @@ def load_model(
421523
language=language,
422524
suppress_numerals=suppress_numerals,
423525
vad_params=default_vad_options,
526+
use_batch_context=use_batch_context,
424527
)

whisperx/benchmark.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
import argparse
import os
import string
import time
from typing import Tuple

import jiwer
import numpy as np
import torch
import torchaudio

import whisperx
10+
11+
def load_tedlium(root: str, download: bool = False, subset: str = "test"):
    """Return the TEDLIUM release-3 dataset, or ``None`` if loading fails.

    Args:
        root: directory that holds (or will receive) the dataset files.
        download: fetch the data if it is not already present locally.
        subset: which split to load (e.g. ``"test"``, ``"dev"``, ``"train"``).
    """
    print(f"Loading TEDLIUM dataset ({subset}) from {root}...")
    try:
        return torchaudio.datasets.TEDLIUM(
            root=root,
            release="release3",
            subset=subset,
            download=download,
        )
    except Exception as e:
        # Best-effort loader: report the problem and let the caller decide
        # what to do with a missing dataset.
        print(f"Error loading dataset: {e}")
        return None
24+
25+
# Translation table that deletes all ASCII punctuation. Built once at import
# time instead of re-creating it (and re-importing ``string``) on every call.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)


def normalize_text(text: str) -> str:
    """Normalize *text* for WER/CER scoring.

    Lowercases, strips ASCII punctuation, and collapses any run of
    whitespace into a single space.
    """
    text = text.lower().translate(_PUNCTUATION_TABLE)
    return " ".join(text.split())
33+
34+
def benchmark(dataset, model_size="large-v2", device="cuda", compute_type="float16", batch_size=4, limit=None):
35+
print(f"Loading WhisperX model: {model_size} on {device} ({compute_type})...")
36+
37+
try:
38+
model = whisperx.load_model(model_size, device, compute_type=compute_type)
39+
except Exception as e:
40+
print(f"Failed to load model: {e}")
41+
return
42+
43+
print("Model loaded.")
44+
45+
total_wer = 0
46+
total_cer = 0
47+
total_latency = 0
48+
total_audio_duration = 0
49+
count = 0
50+
51+
print(f"\nBenchmarking on {limit if limit else len(dataset)} samples...")
52+
53+
# Clear CUDA cache for accurate VRAM measurement
54+
if torch.cuda.is_available():
55+
torch.cuda.empty_cache()
56+
torch.cuda.reset_peak_memory_stats()
57+
initial_vram = torch.cuda.memory_allocated() / 1024**3
58+
print(f"Initial VRAM usage: {initial_vram:.2f} GB")
59+
60+
for i, item in enumerate(dataset):
61+
if limit and i >= limit:
62+
break
63+
64+
waveform, sample_rate, transcript, talk_id, speaker_id, identifier = item
65+
66+
# WhisperX expects audio as a numpy array, float32, mono, 16kHz
67+
# TEDLIUM is likely 16kHz, but let's verify/resample if needed
68+
# waveform is (channels, time)
69+
70+
if sample_rate != 16000:
71+
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
72+
waveform = resampler(waveform)
73+
74+
audio_np = waveform.squeeze().numpy()
75+
76+
duration = len(audio_np) / 16000
77+
total_audio_duration += duration
78+
79+
# Measure Latency
80+
start_time = time.time()
81+
result = model.transcribe(audio_np, batch_size=batch_size)
82+
end_time = time.time()
83+
84+
latency = end_time - start_time
85+
total_latency += latency
86+
87+
# Combine segments for full transcript
88+
hypothesis = " ".join([seg['text'] for seg in result['segments']])
89+
90+
# Normalize
91+
ref_norm = normalize_text(transcript)
92+
hyp_norm = normalize_text(hypothesis)
93+
94+
if not ref_norm.strip():
95+
# Skip empty references to avoid division by zero in WER
96+
continue
97+
98+
# Measure WER/CER
99+
wer = jiwer.wer(ref_norm, hyp_norm)
100+
cer = jiwer.cer(ref_norm, hyp_norm)
101+
102+
total_wer += wer
103+
total_cer += cer
104+
count += 1
105+
106+
print(f"Sample {i}: WER={wer:.2f}, CER={cer:.2f}, Latency={latency:.2f}s, Dur={duration:.2f}s, RTF={latency/duration:.2f}")
107+
108+
if count == 0:
109+
print("No samples processed.")
110+
return
111+
112+
avg_wer = total_wer / count
113+
avg_cer = total_cer / count
114+
avg_rtf = total_latency / total_audio_duration
115+
116+
print("\n--- Benchmark Results ---")
117+
print(f"Average WER: {avg_wer:.4f}")
118+
print(f"Average CER: {avg_cer:.4f}")
119+
print(f"Average RTF (Real Time Factor): {avg_rtf:.4f}")
120+
print(f"Total Latency: {total_latency:.2f}s for {total_audio_duration:.2f}s audio")
121+
122+
if torch.cuda.is_available():
123+
peak_vram = torch.cuda.max_memory_allocated() / 1024**3
124+
print(f"Peak VRAM Usage: {peak_vram:.2f} GB")
125+
else:
126+
print("VRAM Usage: N/A (CPU only)")
127+
128+
if __name__ == "__main__":
    # Command-line entry point: parse options, ensure the dataset root
    # exists, then run the benchmark if the dataset loads successfully.
    cli = argparse.ArgumentParser(description="Benchmark WhisperX on TEDLIUM")
    cli.add_argument("--root", type=str, default="./data", help="Root directory for dataset")
    cli.add_argument("--download", action="store_true", help="Download dataset if not found")
    cli.add_argument("--limit", type=int, default=None, help="Limit number of samples")
    cli.add_argument("--model", type=str, default="large-v2", help="Whisper model size")
    cli.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device")
    cli.add_argument("--batch_size", type=int, default=4, help="Batch size")

    opts = cli.parse_args()

    # Make sure the data directory exists before torchaudio touches it.
    os.makedirs(opts.root, exist_ok=True)

    tedlium = load_tedlium(opts.root, download=opts.download)
    if tedlium:
        benchmark(
            tedlium,
            model_size=opts.model,
            device=opts.device,
            batch_size=opts.batch_size,
            limit=opts.limit,
        )

0 commit comments

Comments
 (0)