11import os
22import whisper
33import torch
4- from typing import Tuple , Optional
4+ import concurrent .futures
5+ import math
6+ import multiprocessing as mp
7+ import torchaudio
8+ import numpy as np
9+ from functools import partial
10+ from threading import Lock
11+ from typing import Tuple , Optional , Dict , List
512from pysrt import SubRipTime
13+ from whisper import Whisper
614from whisper .tokenizer import LANGUAGES
15+ from tqdm import tqdm
716from .subtitle import Subtitle
817from .media_helper import MediaHelper
918from .llm import TranscriptionRecipe , WhisperFlavour
@@ -38,14 +47,20 @@ def __init__(self, recipe: str = TranscriptionRecipe.WHISPER.value, flavour: str
3847 self .__flavour = flavour
3948 self .__media_helper = MediaHelper ()
4049 self .__LOGGER = Logger ().get_logger (__name__ )
50+ self .__lock = Lock ()
4151
42- def transcribe (self , video_file_path : str , language_code : str , initial_prompt : Optional [str ] = None ) -> Tuple [Subtitle , Optional [float ]]:
52+ def transcribe (self ,
53+ video_file_path : str ,
54+ language_code : str ,
55+ initial_prompt : Optional [str ] = None ,
56+ max_char_length : Optional [int ] = None ) -> Tuple [Subtitle , Optional [float ]]:
4357 """Transcribe an audiovisual file and generate subtitles.
4458
4559 Arguments:
4660 video_file_path {string} -- The input video file path.
4761 language_code {string} -- An alpha 3 language code derived from ISO 639-3.
48- initial_prompt {string} -- Optional text to provide the transcribing context or specific phrases.
62+ initial_prompt {string} -- Optional Text to provide the transcribing context or specific phrases.
63+ max_char_length {int} -- Optional Maximum number of characters for each generated subtitle segment.
4964
5065 Returns:
5166 tuple: Generated subtitle after transcription and the detected frame rate
@@ -64,14 +79,24 @@ def transcribe(self, video_file_path: str, language_code: str, initial_prompt: O
6479 self .__LOGGER .info ("Start transcribing the audio..." )
6580 verbose = False if Logger .VERBOSE and not Logger .QUIET else None
6681 self .__LOGGER .debug ("Prompting with: '%s'" % initial_prompt )
67- result = self .__model .transcribe (audio , task = "transcribe" , language = LANGUAGES [lang ], verbose = verbose , initial_prompt = initial_prompt )
82+ result = self .__model .transcribe (audio ,
83+ task = "transcribe" ,
84+ language = LANGUAGES [lang ],
85+ verbose = verbose ,
86+ word_timestamps = True ,
87+ initial_prompt = initial_prompt )
6888 self .__LOGGER .info ("Finished transcribing the audio" )
6989 srt_str = ""
70- for i , segment in enumerate (result ["segments" ], start = 1 ):
71- srt_str += f"{ i } \n " \
72- f"{ Utils .format_timestamp (segment ['start' ])} --> { Utils .format_timestamp (segment ['end' ])} \n " \
73- f"{ segment ['text' ].strip ().replace ('-->' , '->' )} \n " \
74- "\n "
90+ srt_idx = 1
91+ for segment in result ["segments" ]:
92+ if max_char_length is not None and len (segment ["text" ]) > max_char_length :
93+ srt_str , srt_idx = self ._chunk_segment (segment , srt_str , srt_idx , max_char_length )
94+ else :
95+ srt_str += f"{ srt_idx } \n " \
96+ f"{ Utils .format_timestamp (segment ['words' ][0 ]['start' ])} --> { Utils .format_timestamp (segment ['words' ][- 1 ]['end' ])} \n " \
97+ f"{ segment ['text' ].strip ().replace ('-->' , '->' )} \n " \
98+ "\n "
99+ srt_idx += 1
75100 subtitle = Subtitle .load_subrip_str (srt_str )
76101 subtitle , frame_rate = self .__on_frame_timecodes (subtitle , video_file_path )
77102 self .__LOGGER .debug ("Generated the raw subtitle" )
@@ -82,13 +107,19 @@ def transcribe(self, video_file_path: str, language_code: str, initial_prompt: O
82107 else :
83108 raise NotImplementedError (f"{ self .__recipe } ({ self .__flavour } ) is not supported" )
84109
85- def transcribe_with_subtitle_as_prompts (self , video_file_path : str , subtitle_file_path : str , language_code : str ) -> Tuple [Subtitle , Optional [float ]]:
86- """Transcribe an audiovisual file and generate subtitles using the original subtitle as prompts.
110+ def transcribe_with_subtitle_as_prompts (self ,
111+ video_file_path : str ,
112+ subtitle_file_path : str ,
113+ language_code : str ,
114+ max_char_length : Optional [int ] = None ) -> Tuple [Subtitle , Optional [float ]]:
115+ """Transcribe an audiovisual file and generate subtitles using the original subtitle (with accurate time codes) as prompts.
116+
87117
88118 Arguments:
89119 video_file_path {string} -- The input video file path.
90120 subtitle_file_path {string} -- The input subtitle file path to provide prompts.
91121 language_code {string} -- An alpha 3 language code derived from ISO 639-3.
122+ max_char_length {int} -- Optional Maximum number of characters for each generated subtitle segment.
92123
93124 Returns:
94125 tuple: Generated subtitle after transcription and the detected frame rate
@@ -104,27 +135,54 @@ def transcribe_with_subtitle_as_prompts(self, video_file_path: str, subtitle_fil
104135 f'"{ language_code } " is not supported by { self .__recipe } ({ self .__flavour } )' )
105136 audio_file_path = self .__media_helper .extract_audio (video_file_path , True , 16000 )
106137 subtitle = Subtitle .load (subtitle_file_path )
107- segment_paths = []
138+ segment_paths : List [ str ] = []
108139 try :
109140 srt_str = ""
110141 srt_idx = 1
111142 self .__LOGGER .info ("Start transcribing the audio..." )
112- verbose = False if Logger .VERBOSE and not Logger .QUIET else None
113- for sub in subtitle .subs :
143+ segment_paths = []
144+ args = []
145+ longest_segment_char_length = 0
146+ for sub in tqdm (subtitle .subs , desc = "Extracting audio segments" ):
114147 segment_path , _ = self .__media_helper .extract_audio_from_start_to_end (audio_file_path , str (sub .start ), str (sub .end ))
115148 segment_paths .append (segment_path )
116- audio = whisper .load_audio (segment_path )
117- result = self .__model .transcribe (audio , task = "transcribe" , language = LANGUAGES [lang ], verbose = verbose , initial_prompt = sub .text )
149+ args .append ((segment_path , sub .text , self .__lock , self .__LOGGER ))
150+ if len (sub .text ) > longest_segment_char_length :
151+ longest_segment_char_length = len (sub .text )
152+ max_subtitle_char_length = max_char_length or longest_segment_char_length
153+
154+ max_workers = math .ceil (float (os .getenv ("MAX_WORKERS" , mp .cpu_count () / 2 )))
155+ with concurrent .futures .ThreadPoolExecutor (max_workers = max_workers ) as executor :
156+ results = list (executor .map (partial (self ._whisper_transcribe , model = self .__model , lang = lang ), args ))
157+ for sub , result in zip (subtitle .subs , results ):
118158 original_start_in_secs = sub .start .hours * 3600 + sub .start .minutes * 60 + sub .start .seconds + sub .start .milliseconds / 1000.0
119159 original_end_in_secs = sub .end .hours * 3600 + sub .end .minutes * 60 + sub .end .seconds + sub .end .milliseconds / 1000.0
120- for segment in result ["segments" ]:
121- if segment ["end" ] <= segment ["start" ]:
122- continue
160+ if len (result ["segments" ]) == 0 :
123161 srt_str += f"{ srt_idx } \n " \
124- f"{ Utils .format_timestamp (original_start_in_secs + segment [ 'start' ] )} --> { Utils .format_timestamp (min ( original_start_in_secs + segment [ 'end' ], original_end_in_secs ) )} \n " \
125- f"{ segment [ ' text' ] .strip ().replace ('-->' , '->' )} \n " \
162+ f"{ Utils .format_timestamp (original_start_in_secs )} --> { Utils .format_timestamp (original_end_in_secs )} \n " \
163+ f"{ sub . text .strip ().replace ('-->' , '->' )} \n " \
126164 "\n "
127165 srt_idx += 1
166+ else :
167+ for segment in result ["segments" ]:
168+ if segment ["end" ] <= segment ["start" ]:
169+ continue
170+ segment_end = min (original_start_in_secs + segment ["end" ], original_end_in_secs )
171+ if len (segment ["text" ]) > max_subtitle_char_length :
172+ srt_str , srt_idx = self ._chunk_segment (segment ,
173+ srt_str ,
174+ srt_idx ,
175+ max_subtitle_char_length ,
176+ original_start_in_secs ,
177+ original_end_in_secs )
178+ else :
179+ srt_str += f"{ srt_idx } \n " \
180+ f"{ Utils .format_timestamp (original_start_in_secs + segment ['start' ])} --> { Utils .format_timestamp (segment_end )} \n " \
181+ f"{ segment ['text' ].strip ().replace ('-->' , '->' )} \n " \
182+ "\n "
183+ srt_idx += 1
184+ if segment_end == original_end_in_secs :
185+ break
128186 self .__LOGGER .info ("Finished transcribing the audio" )
129187 subtitle = Subtitle .load_subrip_str (srt_str )
130188 subtitle , frame_rate = self .__on_frame_timecodes (subtitle , video_file_path )
@@ -139,6 +197,66 @@ def transcribe_with_subtitle_as_prompts(self, video_file_path: str, subtitle_fil
139197 else :
140198 raise NotImplementedError (f"{ self .__recipe } ({ self .__flavour } ) is not supported" )
141199
@staticmethod
def _whisper_transcribe(args: Tuple, model: Whisper, lang: str) -> Dict:
    """Transcribe one extracted audio segment, prompted with its original subtitle text.

    Arguments:
        args {tuple} -- A 4-tuple of (segment_path, sub_text, lock, logger): the path of the audio segment,
            the subtitle text passed to the model as the initial prompt, the lock guarding the model and
            the logger used for reporting.
        model {Whisper} -- The Whisper model used for the transcription.
        lang {string} -- The key used to look up the language name in whisper.tokenizer.LANGUAGES.

    Returns:
        dict: The result returned by the model, or {"segments": []} if loading or transcribing failed.
    """
    segment_path, sub_text, lock, logger = args
    verbose = False if Logger.VERBOSE and not Logger.QUIET else None
    try:
        # The sample rate is discarded here.
        # NOTE(review): assumes the segment was already extracted at the rate the model expects -- confirm against the caller.
        waveform, _ = torchaudio.load(segment_path)
        # torchaudio.load yields a (channels, frames) tensor, so even a mono file is 2-D.
        # Always collapse the channel axis so that the model receives a 1-D waveform;
        # averaging a single channel leaves its samples unchanged.
        if waveform.dim() > 1:
            waveform = waveform.mean(dim=0)
        audio = waveform.numpy().astype(np.float32)
        # The model is shared by all callers, hence only one transcription runs at a time.
        with lock:
            result = model.transcribe(audio,
                                      task="transcribe",
                                      language=LANGUAGES[lang],
                                      verbose=verbose,
                                      initial_prompt=sub_text,
                                      word_timestamps=True)
        logger.debug("Segment transcribed : %s", result)
        return result
    except Exception as e:
        # Deliberate best effort: a failed segment is reported and returned as an empty result
        # so that one bad segment does not abort the whole transcription.
        logger.warning("Error while transcribing segment: %s", e)
        return {"segments": []}
221+
@staticmethod
def _chunk_segment(segment: Dict,
                   srt_str: str,
                   srt_idx: int,
                   max_subtitle_char_length: int,
                   start_offset: float = 0.0,
                   end_ceiling: float = float("inf")) -> Tuple[str, int]:
    """Split a transcribed segment into SubRip cues bounded by a maximum character length.

    Words are accumulated until adding the next one would exceed the limit, at which point the
    accumulated words are emitted as one cue. A single word longer than the limit is kept whole.

    Arguments:
        segment {dict} -- A transcribed segment whose "words" entry lists dicts with "word", "start" and "end".
        srt_str {string} -- The SubRip text generated so far, to which the new cues are appended.
        srt_idx {int} -- The index to assign to the first cue generated.
        max_subtitle_char_length {int} -- Maximum number of characters for each generated cue.
        start_offset {float} -- Seconds added to every word time code (default: {0.0}).
        end_ceiling {float} -- Upper bound in seconds applied to every cue end time (default: {inf}).

    Returns:
        tuple: The extended SubRip text and the next unused cue index.
    """
    cues = []

    def emit(text: str, start_in_secs: float, end_in_secs: float) -> None:
        # Cue indices continue from srt_idx in the order the cues are emitted.
        cue_start = Utils.format_timestamp(start_offset + start_in_secs)
        cue_end = Utils.format_timestamp(min(start_offset + end_in_secs, end_ceiling))
        cues.append(f"{srt_idx + len(cues)}\n"
                    f"{cue_start} --> {cue_end}\n"
                    f"{text.strip().replace('-->', '->')}\n"
                    "\n")

    chunked_text = ""
    # None (rather than 0.0) marks "no word seen yet" since 0.0 is a legitimate start time:
    # with a 0.0 sentinel every following word of the first chunk would overwrite the start.
    chunk_start_in_secs = None
    chunk_end_in_secs = 0.0
    chunk_char_length = 0

    for word in segment["words"]:
        word_text = word["word"]
        exceeds_limit = chunk_char_length + len(word_text) > max_subtitle_char_length
        if exceeds_limit and chunked_text.strip() != "":
            # chunk_end_in_secs still holds the end of the previous word, i.e. the end of this chunk.
            emit(chunked_text, chunk_start_in_secs, chunk_end_in_secs)
            chunked_text = word_text
            chunk_start_in_secs = word["start"]
            chunk_char_length = len(word_text)
        else:
            if chunk_start_in_secs is None:
                chunk_start_in_secs = word["start"]
            chunked_text += word_text
            chunk_char_length += len(word_text)
        chunk_end_in_secs = word["end"]

    # Flush the remainder, skipping a blank one for the same reason as inside the loop.
    if chunked_text.strip() != "":
        emit(chunked_text, chunk_start_in_secs, chunk_end_in_secs)

    return srt_str + "".join(cues), srt_idx + len(cues)
259+
142260 def __on_frame_timecodes (self , subtitle : Subtitle , video_file_path : str ) -> Tuple [Subtitle , Optional [float ]]:
143261 frame_rate = None
144262 try :
0 commit comments