1818
1919logger = get_logger (__name__ )
2020
21+
2122def find_numeral_symbol_tokens (tokenizer ):
2223 numeral_symbol_tokens = []
2324 for i in range (tokenizer .eot ):
@@ -39,48 +40,28 @@ def generate_segment_batched(
3940 tokenizer : Tokenizer ,
4041 options : TranscriptionOptions ,
4142 encoder_output = None ,
42- use_batch_context : bool = False ,
43- previous_batch_context_tokens : List [List [int ]] = None ,
4443 ):
4544 batch_size = features .shape [0 ]
46- if previous_batch_context_tokens is None :
47- previous_batch_context_tokens = [[] for _ in range (batch_size )]
48-
49- initial_prompt_tokens = []
45+ all_tokens = []
46+ prompt_reset_since = 0
5047 if options .initial_prompt is not None :
5148 initial_prompt = " " + options .initial_prompt .strip ()
5249 initial_prompt_tokens = tokenizer .encode (initial_prompt )
53-
54- batch_tokens = []
55- for i in range (batch_size ):
56- all_tokens = list (initial_prompt_tokens )
57- if use_batch_context :
58- if i < len (previous_batch_context_tokens ):
59- ctx = previous_batch_context_tokens [i ]
60- if ctx :
61- # 223 is max prompt tokens
62- available = 223 - len (all_tokens )
63- if available > 0 :
64- all_tokens .extend (ctx [- available :])
65- batch_tokens .append (all_tokens )
66-
67- max_batch_tokens = max ([len (t ) for t in batch_tokens ] + [0 ])
68-
69- prompts = [
70- self .get_prompt (
71- tokenizer ,
72- [tokenizer .eot ] * (max_batch_tokens - len (t )) + t ,
73- without_timestamps = options .without_timestamps ,
74- prefix = options .prefix ,
75- hotwords = options .hotwords
76- ) for t in batch_tokens
77- ]
50+ all_tokens .extend (initial_prompt_tokens )
51+ previous_tokens = all_tokens [prompt_reset_since :]
52+ prompt = self .get_prompt (
53+ tokenizer ,
54+ previous_tokens ,
55+ without_timestamps = options .without_timestamps ,
56+ prefix = options .prefix ,
57+ hotwords = options .hotwords
58+ )
7859
7960 encoder_output = self .encode (features )
8061
8162 result = self .model .generate (
8263 encoder_output ,
83- prompts ,
64+ [ prompt ] * batch_size ,
8465 beam_size = options .beam_size ,
8566 patience = options .patience ,
8667 length_penalty = options .length_penalty ,
@@ -101,9 +82,9 @@ def decode_batch(tokens: List[List[int]]) -> List[str]:
10182 return tokenizer .tokenizer .decode_batch (res )
10283
10384 text = decode_batch (tokens_batch )
85+
10486 return text
10587
106-
10788 def encode (self , features : np .ndarray ) -> ctranslate2 .StorageView :
10889 # When the model is running on multiple GPUs, the encoder output should be moved
10990 # to the CPU since we don't know which GPU will handle the next job.
@@ -134,15 +115,13 @@ def __init__(
134115 framework = "pt" ,
135116 language : Optional [str ] = None ,
136117 suppress_numerals : bool = False ,
137- use_batch_context : bool = False ,
138118 ** kwargs ,
139119 ):
140120 self .model = model
141121 self .tokenizer = tokenizer
142122 self .options = options
143123 self .preset_language = language
144124 self .suppress_numerals = suppress_numerals
145- self .use_batch_context = use_batch_context
146125 self ._batch_size = kwargs .pop ("batch_size" , None )
147126 self ._num_workers = 1
148127 self ._preprocess_params , self ._forward_params , self ._postprocess_params = self ._sanitize_parameters (** kwargs )
@@ -163,8 +142,6 @@ def __init__(
163142 super (Pipeline , self ).__init__ ()
164143 self .vad_model = vad
165144 self ._vad_params = vad_params
166- self .previous_batch_context_tokens = []
167-
168145
169146 def _sanitize_parameters (self , ** kwargs ):
170147 preprocess_kwargs = {}
@@ -183,35 +160,7 @@ def preprocess(self, audio):
183160 return {'inputs' : features }
184161
185162 def _forward (self , model_inputs ):
186- current_batch_size = model_inputs ['inputs' ].shape [0 ]
187- # Ideally, batch[i] corresponds to stream[i].
188- # This holds if batch_size == number of streams.
189- valid_contexts = self .previous_batch_context_tokens [:current_batch_size ]
190-
191- outputs = self .model .generate_segment_batched (
192- model_inputs ['inputs' ],
193- self .tokenizer ,
194- self .options ,
195- use_batch_context = self .use_batch_context ,
196- previous_batch_context_tokens = valid_contexts ,
197- )
198- if self .use_batch_context :
199- initial_prompt_length = 0
200- if self .options .initial_prompt is not None :
201- initial_prompt = " " + self .options .initial_prompt .strip ()
202- initial_prompt_length = len (self .tokenizer .encode (initial_prompt ))
203-
204- # Use 220 instead of 224 to be safe
205- max_context_window = max (0 , 220 - initial_prompt_length )
206-
207- for i , text in enumerate (outputs ):
208- if i < len (self .previous_batch_context_tokens ):
209- # Filter out special tokens (timestamps, SOT, EOT, etc.)
210- # We only want the text content for context.
211- tokens = [t for t in self .tokenizer .encode (text ) if t < self .tokenizer .eot ]
212- self .previous_batch_context_tokens [i ].extend (tokens )
213- self .previous_batch_context_tokens [i ] = self .previous_batch_context_tokens [i ][- max_context_window :]
214-
163+ outputs = self .model .generate_segment_batched (model_inputs ['inputs' ], self .tokenizer , self .options )
215164 return {'text' : outputs }
216165
217166 def postprocess (self , model_outputs ):
@@ -252,14 +201,6 @@ def transcribe(
252201 ) -> TranscriptionResult :
253202 if isinstance (audio , str ):
254203 audio = load_audio (audio )
255-
256- batch_size = batch_size or self ._batch_size
257- # Initialize context for each stream.
258- # We have 'batch_size' concurrent streams.
259- if batch_size is None or batch_size < 1 :
260- batch_size = 1
261-
262- self .previous_batch_context_tokens = [[] for _ in range (batch_size )]
263204
264205 def data (audio , segments ):
265206 for seg in segments :
@@ -311,33 +252,10 @@ def data(audio, segments):
311252 new_suppressed_tokens = numeral_symbol_tokens + self .options .suppress_tokens
312253 new_suppressed_tokens = list (set (new_suppressed_tokens ))
313254 self .options = replace (self .options , suppress_tokens = new_suppressed_tokens )
314-
255+
315256 segments : List [SingleSegment ] = []
316257 batch_size = batch_size or self ._batch_size
317258 total_segments = len (vad_segments )
318-
319- if batch_size > 1 and self .use_batch_context :
320- num_streams = batch_size
321- # Distribute segments into streams
322- # Manual split
323- k , m = divmod (len (vad_segments ), num_streams )
324- # lengths of each part: first m parts have k+1, rest have k
325- stream_segments = []
326- start_idx = 0
327- for i in range (num_streams ):
328- part_len = k + 1 if i < m else k
329- stream_segments .append (vad_segments [start_idx : start_idx + part_len ])
330- start_idx += part_len
331- # Interleave
332- # We need to pick [s0[0], s1[0], s2[0]... s0[1], s1[1]...]
333- interleaved_segments = []
334- max_len = max (len (s ) for s in stream_segments )
335- for i in range (max_len ):
336- for stream in stream_segments :
337- if i < len (stream ):
338- interleaved_segments .append (stream [i ])
339- vad_segments = interleaved_segments
340-
341259 for idx , out in enumerate (self .__call__ (data (audio , vad_segments ), batch_size = batch_size , num_workers = num_workers )):
342260 if print_progress :
343261 base_progress = ((idx + 1 ) / total_segments ) * 100
@@ -356,25 +274,6 @@ def data(audio, segments):
356274 }
357275 )
358276
359- if self .use_batch_context and batch_size > 1 :
360- last_stream_index = (total_segments - 1 ) % batch_size
361- final_context = self .previous_batch_context_tokens [last_stream_index ]
362- # Prepare context for the wrap-around re-run
363- # ONLY Stream 0 (which processes the start of the file) should get the context (which comes from the end of the file).
364- # All other streams should have EMPTY context for this re-run to avoid self-referencing loops (feeding Segment N to Segment N).
365- new_rerun_context = [[] for _ in range (batch_size )]
366- new_rerun_context [0 ] = final_context
367- # Temporarily overwrite previous_batch_context_tokens for the re-run
368- self .previous_batch_context_tokens = new_rerun_context
369- first_batch_segments = vad_segments [:batch_size ]
370- # Runs the model again just on 'first_batch_segments'
371- for i , out in enumerate (self .__call__ (data (audio , first_batch_segments ), batch_size = batch_size , num_workers = num_workers )):
372- text = out ['text' ]
373- # L398: Overwrite the existing text with the new wrap-around text
374- segments [i ]['text' ] = text
375- # Sort segments by start time to restore original order
376- segments .sort (key = lambda x : x ['start' ])
377-
378277 # revert the tokenizer if multilingual inference is enabled
379278 if self .preset_language is None :
380279 self .tokenizer = None
@@ -390,8 +289,8 @@ def detect_language(self, audio: np.ndarray) -> str:
390289 logger .warning ("Audio is shorter than 30s, language detection may be inaccurate" )
391290 model_n_mels = self .model .feat_kwargs .get ("feature_size" )
392291 segment = log_mel_spectrogram (audio [: N_SAMPLES ],
393- n_mels = model_n_mels if model_n_mels is not None else 80 ,
394- padding = 0 if audio .shape [0 ] >= N_SAMPLES else N_SAMPLES - audio .shape [0 ])
292+ n_mels = model_n_mels if model_n_mels is not None else 80 ,
293+ padding = 0 if audio .shape [0 ] >= N_SAMPLES else N_SAMPLES - audio .shape [0 ])
395294 encoder_output = self .model .encode (segment )
396295 results = self .model .model .detect_language (encoder_output )
397296 language_token , language_probability = results [0 ][0 ]
@@ -416,7 +315,6 @@ def load_model(
416315 local_files_only = False ,
417316 threads = 4 ,
418317 use_auth_token : Optional [Union [str , bool ]] = None ,
419- use_batch_context : bool = False ,
420318) -> FasterWhisperPipeline :
421319 """Load a Whisper model for inference.
422320 Args:
@@ -523,5 +421,4 @@ def load_model(
523421 language = language ,
524422 suppress_numerals = suppress_numerals ,
525423 vad_params = default_vad_options ,
526- use_batch_context = use_batch_context ,
527424 )
0 commit comments