1818
1919logger = get_logger (__name__ )
2020
21-
2221def find_numeral_symbol_tokens (tokenizer ):
2322 numeral_symbol_tokens = []
2423 for i in range (tokenizer .eot ):
@@ -40,28 +39,48 @@ def generate_segment_batched(
4039 tokenizer : Tokenizer ,
4140 options : TranscriptionOptions ,
4241 encoder_output = None ,
42+ use_batch_context : bool = False ,
43+ previous_batch_context_tokens : List [List [int ]] = None ,
4344 ):
4445 batch_size = features .shape [0 ]
45- all_tokens = []
46- prompt_reset_since = 0
46+ if previous_batch_context_tokens is None :
47+ previous_batch_context_tokens = [[] for _ in range (batch_size )]
48+
49+ initial_prompt_tokens = []
4750 if options .initial_prompt is not None :
4851 initial_prompt = " " + options .initial_prompt .strip ()
4952 initial_prompt_tokens = tokenizer .encode (initial_prompt )
50- all_tokens .extend (initial_prompt_tokens )
51- previous_tokens = all_tokens [prompt_reset_since :]
52- prompt = self .get_prompt (
53- tokenizer ,
54- previous_tokens ,
55- without_timestamps = options .without_timestamps ,
56- prefix = options .prefix ,
57- hotwords = options .hotwords
58- )
53+
54+ batch_tokens = []
55+ for i in range (batch_size ):
56+ all_tokens = list (initial_prompt_tokens )
57+ if use_batch_context :
58+ if i < len (previous_batch_context_tokens ):
59+ ctx = previous_batch_context_tokens [i ]
60+ if ctx :
61+ # 223 is max prompt tokens
62+ available = 223 - len (all_tokens )
63+ if available > 0 :
64+ all_tokens .extend (ctx [- available :])
65+ batch_tokens .append (all_tokens )
66+
67+ max_batch_tokens = max ([len (t ) for t in batch_tokens ] + [0 ])
68+
69+ prompts = [
70+ self .get_prompt (
71+ tokenizer ,
72+ [tokenizer .eot ] * (max_batch_tokens - len (t )) + t ,
73+ without_timestamps = options .without_timestamps ,
74+ prefix = options .prefix ,
75+ hotwords = options .hotwords
76+ ) for t in batch_tokens
77+ ]
5978
6079 encoder_output = self .encode (features )
6180
6281 result = self .model .generate (
6382 encoder_output ,
64- [ prompt ] * batch_size ,
83+ prompts ,
6584 beam_size = options .beam_size ,
6685 patience = options .patience ,
6786 length_penalty = options .length_penalty ,
@@ -82,9 +101,9 @@ def decode_batch(tokens: List[List[int]]) -> List[str]:
82101 return tokenizer .tokenizer .decode_batch (res )
83102
84103 text = decode_batch (tokens_batch )
85-
86104 return text
87105
106+
88107 def encode (self , features : np .ndarray ) -> ctranslate2 .StorageView :
89108 # When the model is running on multiple GPUs, the encoder output should be moved
90109 # to the CPU since we don't know which GPU will handle the next job.
@@ -115,13 +134,15 @@ def __init__(
115134 framework = "pt" ,
116135 language : Optional [str ] = None ,
117136 suppress_numerals : bool = False ,
137+ use_batch_context : bool = False ,
118138 ** kwargs ,
119139 ):
120140 self .model = model
121141 self .tokenizer = tokenizer
122142 self .options = options
123143 self .preset_language = language
124144 self .suppress_numerals = suppress_numerals
145+ self .use_batch_context = use_batch_context
125146 self ._batch_size = kwargs .pop ("batch_size" , None )
126147 self ._num_workers = 1
127148 self ._preprocess_params , self ._forward_params , self ._postprocess_params = self ._sanitize_parameters (** kwargs )
@@ -142,6 +163,8 @@ def __init__(
142163 super (Pipeline , self ).__init__ ()
143164 self .vad_model = vad
144165 self ._vad_params = vad_params
166+ self .previous_batch_context_tokens = []
167+
145168
146169 def _sanitize_parameters (self , ** kwargs ):
147170 preprocess_kwargs = {}
@@ -160,7 +183,35 @@ def preprocess(self, audio):
160183 return {'inputs' : features }
161184
def _forward(self, model_inputs):
    """Run one batched generation step, threading per-stream rolling context.

    Parameters
    ----------
    model_inputs : dict
        Must contain ``'inputs'``: an array whose first dimension is the
        batch size of audio feature segments for this step.

    Returns
    -------
    dict
        ``{'text': outputs}`` where ``outputs`` is the list of decoded
        strings returned by ``generate_segment_batched``, one per batch
        element.

    Side effects
    ------------
    When ``self.use_batch_context`` is enabled, appends each output's text
    tokens to ``self.previous_batch_context_tokens[i]`` and trims every
    stream's context to the allowed prompt window.
    """
    current_batch_size = model_inputs['inputs'].shape[0]
    # Ideally, batch[i] corresponds to stream[i].
    # This holds if batch_size == number of streams.
    valid_contexts = self.previous_batch_context_tokens[:current_batch_size]

    outputs = self.model.generate_segment_batched(
        model_inputs['inputs'],
        self.tokenizer,
        self.options,
        use_batch_context=self.use_batch_context,
        previous_batch_context_tokens=valid_contexts,
    )
    if self.use_batch_context:
        initial_prompt_length = 0
        if self.options.initial_prompt is not None:
            initial_prompt = " " + self.options.initial_prompt.strip()
            initial_prompt_length = len(self.tokenizer.encode(initial_prompt))

        # Use 220 instead of 224 to be safe
        max_context_window = max(0, 220 - initial_prompt_length)

        for i, text in enumerate(outputs):
            if i < len(self.previous_batch_context_tokens):
                # Filter out special tokens (timestamps, SOT, EOT, etc.)
                # We only want the text content for context.
                tokens = [t for t in self.tokenizer.encode(text) if t < self.tokenizer.eot]
                self.previous_batch_context_tokens[i].extend(tokens)
                # BUG FIX: with a zero-width window, lst[-0:] is the WHOLE
                # list, which silently kept unbounded context whenever the
                # initial prompt alone filled the 220-token budget. A zero
                # window must clear the context instead.
                if max_context_window > 0:
                    self.previous_batch_context_tokens[i] = (
                        self.previous_batch_context_tokens[i][-max_context_window:]
                    )
                else:
                    self.previous_batch_context_tokens[i] = []

    return {'text': outputs}
165216
166217 def postprocess (self , model_outputs ):
@@ -201,6 +252,14 @@ def transcribe(
201252 ) -> TranscriptionResult :
202253 if isinstance (audio , str ):
203254 audio = load_audio (audio )
255+
256+ batch_size = batch_size or self ._batch_size
257+ # Initialize context for each stream.
258+ # We have 'batch_size' concurrent streams.
259+ if batch_size is None or batch_size < 1 :
260+ batch_size = 1
261+
262+ self .previous_batch_context_tokens = [[] for _ in range (batch_size )]
204263
205264 def data (audio , segments ):
206265 for seg in segments :
@@ -252,10 +311,33 @@ def data(audio, segments):
252311 new_suppressed_tokens = numeral_symbol_tokens + self .options .suppress_tokens
253312 new_suppressed_tokens = list (set (new_suppressed_tokens ))
254313 self .options = replace (self .options , suppress_tokens = new_suppressed_tokens )
255-
314+
256315 segments : List [SingleSegment ] = []
257316 batch_size = batch_size or self ._batch_size
258317 total_segments = len (vad_segments )
318+
319+ if batch_size > 1 and self .use_batch_context :
320+ num_streams = batch_size
321+ # Distribute segments into streams
322+ # Manual split
323+ k , m = divmod (len (vad_segments ), num_streams )
324+ # lengths of each part: first m parts have k+1, rest have k
325+ stream_segments = []
326+ start_idx = 0
327+ for i in range (num_streams ):
328+ part_len = k + 1 if i < m else k
329+ stream_segments .append (vad_segments [start_idx : start_idx + part_len ])
330+ start_idx += part_len
331+ # Interleave
332+ # We need to pick [s0[0], s1[0], s2[0]... s0[1], s1[1]...]
333+ interleaved_segments = []
334+ max_len = max (len (s ) for s in stream_segments )
335+ for i in range (max_len ):
336+ for stream in stream_segments :
337+ if i < len (stream ):
338+ interleaved_segments .append (stream [i ])
339+ vad_segments = interleaved_segments
340+
259341 for idx , out in enumerate (self .__call__ (data (audio , vad_segments ), batch_size = batch_size , num_workers = num_workers )):
260342 if print_progress :
261343 base_progress = ((idx + 1 ) / total_segments ) * 100
@@ -274,6 +356,25 @@ def data(audio, segments):
274356 }
275357 )
276358
359+ if self .use_batch_context and batch_size > 1 :
360+ last_stream_index = (total_segments - 1 ) % batch_size
361+ final_context = self .previous_batch_context_tokens [last_stream_index ]
362+ # Prepare context for the wrap-around re-run
363+ # ONLY Stream 0 (which processes the start of the file) should get the context (which comes from the end of the file).
364+ # All other streams should have EMPTY context for this re-run to avoid self-referencing loops (feeding Segment N to Segment N).
365+ new_rerun_context = [[] for _ in range (batch_size )]
366+ new_rerun_context [0 ] = final_context
367+ # Temporarily overwrite previous_batch_context_tokens for the re-run
368+ self .previous_batch_context_tokens = new_rerun_context
369+ first_batch_segments = vad_segments [:batch_size ]
370+ # Runs the model again just on 'first_batch_segments'
371+ for i , out in enumerate (self .__call__ (data (audio , first_batch_segments ), batch_size = batch_size , num_workers = num_workers )):
372+ text = out ['text' ]
373+ # L398: Overwrite the existing text with the new wrap-around text
374+ segments [i ]['text' ] = text
375+ # Sort segments by start time to restore original order
376+ segments .sort (key = lambda x : x ['start' ])
377+
277378 # revert the tokenizer if multilingual inference is enabled
278379 if self .preset_language is None :
279380 self .tokenizer = None
@@ -289,8 +390,8 @@ def detect_language(self, audio: np.ndarray) -> str:
289390 logger .warning ("Audio is shorter than 30s, language detection may be inaccurate" )
290391 model_n_mels = self .model .feat_kwargs .get ("feature_size" )
291392 segment = log_mel_spectrogram (audio [: N_SAMPLES ],
292- n_mels = model_n_mels if model_n_mels is not None else 80 ,
293- padding = 0 if audio .shape [0 ] >= N_SAMPLES else N_SAMPLES - audio .shape [0 ])
393+ n_mels = model_n_mels if model_n_mels is not None else 80 ,
394+ padding = 0 if audio .shape [0 ] >= N_SAMPLES else N_SAMPLES - audio .shape [0 ])
294395 encoder_output = self .model .encode (segment )
295396 results = self .model .model .detect_language (encoder_output )
296397 language_token , language_probability = results [0 ][0 ]
@@ -315,6 +416,7 @@ def load_model(
315416 local_files_only = False ,
316417 threads = 4 ,
317418 use_auth_token : Optional [Union [str , bool ]] = None ,
419+ use_batch_context : bool = False ,
318420) -> FasterWhisperPipeline :
319421 """Load a Whisper model for inference.
320422 Args:
@@ -421,4 +523,5 @@ def load_model(
421523 language = language ,
422524 suppress_numerals = suppress_numerals ,
423525 vad_params = default_vad_options ,
526+ use_batch_context = use_batch_context ,
424527 )
0 commit comments