1010import os
1111import unicodedata
1212
13+ from transformers import pipeline
14+
# Pick the compute device once at import time: prefer CUDA when present.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
1416
1517class CustomTranslator :
@@ -19,147 +21,218 @@ def __init__(self, output_dir="output"):
1921 self .translation_method = ""
2022 self .output_dir = output_dir
2123 os .makedirs (self .output_dir , exist_ok = True )
22- # Initialize other attributes as needed
2324
24- def load_models (self ):
25+ def load_whisper_model (self ):
2526 self .processor = WhisperProcessor .from_pretrained ("openai/whisper-large-v3" )
2627 self .model = WhisperForConditionalGeneration .from_pretrained ("openai/whisper-large-v3" ).to (device )
27- # self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
2828
29- def process_audio_chunk (self , input_path , target_language , chunk_idx , output_path , translation_method ):
30- try :
31- if translation_method == 'Local' :
32- self .load_models ()
33- start_time = time .time ()
34- # Load input audio file using librosa
35- input_waveform , input_sampling_rate = librosa .load (input_path , sr = None , mono = True )
29+ def unload_whisper_model (self ):
30+ del self .processor
31+ del self .model
3632
37- # Convert NumPy array to PyTorch tensor if needed
38- if not isinstance ( input_waveform , torch . Tensor ):
39- input_waveform = torch . tensor ( input_waveform )
33+ def load_mbart_model ( self ):
34+ self . mbart_model = MBartForConditionalGeneration . from_pretrained ( "SnypzZz/Llama2-13b-Language-translate" ). to ( device )
35+ self . mbart_tokenizer = MBart50TokenizerFast . from_pretrained ( "SnypzZz/Llama2-13b-Language-translate" , src_lang = "en_XX" , device = device )
4036
41- forced_decoder_ids = self .processor .get_decoder_prompt_ids (language = target_language , task = "translate" )
37+ def unload_mbart_model (self ):
38+ del self .mbart_model
39+ del self .mbart_tokenizer
40+
41+ def load_tts_model (self ):
42+ self .tts = TTS ("tts_models/multilingual/multi-dataset/xtts_v2" ).to (device )
4243
43- # Ensure the input audio has a proper frame rate
44- if input_sampling_rate != 16000 :
45- resampler = torchaudio .transforms .Resample (orig_freq = input_sampling_rate , new_freq = 16000 )
46- input_waveform = resampler (input_waveform )
44+ def unload_tts_model (self ):
45+ del self .tts
4746
48- # Process the input audio with the processor
49- input_features = self .processor (input_waveform .numpy (), sampling_rate = 16000 , return_tensors = "pt" )
47+ def process_audio_chunk (self , input_path , target_language , chunk_idx , output_path , translation_method , batch_size = 4 ):
48+ try :
49+ start_time = time .time ()
50+
51+ self .load_whisper_model ()
52+
53+ # Load audio waveform
54+ input_waveform , input_sampling_rate = librosa .load (input_path , sr = None , mono = True )
55+
56+ if not isinstance (input_waveform , torch .Tensor ):
57+ input_waveform = torch .tensor (input_waveform )
58+
59+ if input_sampling_rate != 16000 :
60+ resampler = torchaudio .transforms .Resample (orig_freq = input_sampling_rate , new_freq = 16000 )
61+ input_waveform = resampler (torch .tensor (input_waveform ).clone ().detach ()).numpy ()
62+
63+ # Prepare forced decoder IDs
64+ forced_decoder_ids = self .processor .get_decoder_prompt_ids (language = target_language , task = "translate" )
65+
66+ # Create batches of input features
67+ input_features = self .processor (
68+ input_waveform ,
69+ sampling_rate = 16000 ,
70+ return_tensors = "pt" ,
71+ padding = True
72+ )
73+ input_features = {k : v .to (device ) for k , v in input_features .items ()}
74+ input_batches = torch .split (input_features ["input_features" ], batch_size , dim = 0 )
75+
76+ # Process batches
77+ transcriptions = []
78+ for batch in input_batches :
79+ with torch .no_grad ():
80+ predicted_ids = self .model .generate (batch , forced_decoder_ids = forced_decoder_ids , max_length = 448 )
81+ transcriptions .extend (self .processor .batch_decode (predicted_ids , skip_special_tokens = True ))
82+
83+ # Combine transcriptions
84+ transcription = " " .join (transcriptions )
85+
86+ del input_waveform , input_sampling_rate
87+
88+ end_time = time .time ()
89+ execution_time = (end_time - start_time ) / 60
90+ print (f"Transcription Execution time: { execution_time :.2f} minutes" )
91+
92+ words = transcription .split ()
93+ cleaned_words = [words [0 ]]
94+ for word in words [1 :]:
95+ if word != cleaned_words [- 1 ]:
96+ cleaned_words .append (word )
97+ cleaned_str = ' ' .join (cleaned_words )
98+
99+ sentences = cleaned_str .split ('.' )
100+ cleaned_sentences = [sentences [0 ]]
101+ for sentence in sentences [1 :]:
102+ if sentence != cleaned_sentences [- 1 ]:
103+ cleaned_sentences .append (sentence )
104+ cleaned_transcription = '.' .join (cleaned_sentences )
105+
106+ transcription = cleaned_transcription
107+ print ('Speech recognition and translate to English text: ' + str (transcription ))
108+
109+ Translation_chunk_output_path = os .path .join (self .output_dir , f"{ os .path .splitext (os .path .basename (output_path ))[0 ]} _Translation_chunk{ chunk_idx + 1 } .wav" )
110+
111+ if target_language != 'en' and translation_method == 'Llama2-13b' :
112+ print ("Local text translation started.." )
113+ start_time = time .time ()
114+ self .load_mbart_model ()
50115
51- # Move input features to the device used by the model
52- input_features = { k : v .to (device ) for k , v in input_features . items ()}
116+ inputs = self . mbart_tokenizer ( transcription , return_tensors = "pt" )
117+ input_ids = inputs [ "input_ids" ] .to (device )
53118
54- # Generate token ids
55- predicted_ids = self .model .generate (input_features ["input_features" ], forced_decoder_ids = forced_decoder_ids )
119+ language_mapping = {
120+ "en" : "en_XX" , "es" : "es_XX" , "fr" : "fr_XX" , "de" : "de_DE" ,
121+ "ja" : "ja_XX" , "ko" : "ko_KR" , "tr" : "tr_TR" , "ar" : "ar_AR" ,
122+ "ru" : "ru_RU" , "he" : "he_IL" , "hi" : "hi_IN" , "it" : "it_IT" ,
123+ "pt" : "pt_XX" , "zh" : "zh_CN" , "cs" : "cs_CZ" , "nl" : "nl_XX" , "pl" : "pl_PL" ,
124+ }
125+ model_target_language = language_mapping .get (target_language , "en_XX" )
56126
57- # Decode token ids to text
58- transcription = self .processor . batch_decode ( predicted_ids , skip_special_tokens = True )[ 0 ]
127+ # Generate tokens on the GPU
128+ generated_tokens = self .mbart_model . generate ( input_ids = input_ids , forced_bos_token_id = self . mbart_tokenizer . lang_code_to_id [ model_target_language ])
59129
60- del input_waveform , input_sampling_rate
130+ # Decode and join the translated text
131+ translated_text = self .mbart_tokenizer .batch_decode (generated_tokens , skip_special_tokens = True )
132+ translated_text = ", " .join (translated_text )
61133
134+ self .unload_mbart_model ()
135+
136+ print ('Mbart Translation: ' + str (translated_text ))
62137 end_time = time .time ()
63138 execution_time = (end_time - start_time ) / 60
64139 print (f"Transcription Execution time: { execution_time :.2f} minutes" )
65140
66- # Fix a bug: Text validation check if we have duplicate successive words
67- words = transcription .split ()
68- cleaned_words = [words [0 ]]
69-
70- for word in words [1 :]:
71- if word != cleaned_words [- 1 ]:
72- cleaned_words .append (word )
73-
74- cleaned_str = ' ' .join (cleaned_words )
75-
76- transcription = cleaned_str
77-
78- # Fix duplicate successive sentences
79- sentences = transcription .split ('.' )
80- cleaned_sentences = [sentences [0 ]]
81-
82- for sentence in sentences [1 :]:
83- if sentence != cleaned_sentences [- 1 ]:
84- cleaned_sentences .append (sentence )
85-
86- cleaned_transcription = '.' .join (cleaned_sentences )
87-
88- transcription = cleaned_transcription
89- print ('Speech recognition and translate to English text: ' + str (transcription ))
90-
91- Translation_chunk_output_path = os .path .join (self .output_dir , f"{ os .path .splitext (os .path .basename (output_path ))[0 ]} _Translation_chunk{ chunk_idx + 1 } .wav" )
92-
93- # If target language is English, skip text translation
94- if target_language != 'en' :
95- # Local text translation
96- print ("Local text translation started.." )
97- start_time = time .time ()
98- tt = MBartForConditionalGeneration .from_pretrained ("SnypzZz/Llama2-13b-Language-translate" ).to (device )
99- tokenizer = MBart50TokenizerFast .from_pretrained ("SnypzZz/Llama2-13b-Language-translate" , src_lang = "en_XX" , device = device )
100-
101- # Tokenize and convert to PyTorch tensor
102- inputs = tokenizer (transcription , return_tensors = "pt" )
103- input_ids = inputs ["input_ids" ].to (device )
104-
105- # Map target languages to model language codes
106- language_mapping = {
107- "en" : "en_XX" ,
108- "es" : "es_XX" ,
109- "fr" : "fr_XX" ,
110- "de" : "de_DE" ,
111- "ja" : "ja_XX" ,
112- "ko" : "ko_KR" ,
113- "tr" : "tr_TR" ,
114- "ar" : "ar_AR" ,
115- "ru" : "ru_RU" ,
116- "he" : "he_IL" ,
117- "hi" : "hi_IN" ,
118- "it" : "it_IT" ,
119- "pt" : "pt_XX" ,
120- "zh" : "zh_CN" ,
121- "cs" : "cs_CZ" ,
122- "nl" : "nl_XX" ,
123- "pl" : "pl_PL" ,
124- }
125-
126- # Set the target language based on the mapping
127- model_target_language = language_mapping .get (target_language , "en_XX" )
128-
129- # Generate tokens on the GPU
130- generated_tokens = tt .generate (input_ids = input_ids , forced_bos_token_id = tokenizer .lang_code_to_id [model_target_language ])
131-
132- # Decode and join the translated text
133- translated_text = tokenizer .batch_decode (generated_tokens , skip_special_tokens = True )
134- translated_text = ", " .join (translated_text )
135-
136- logging .info (f"Processing successful. Translated text: { translated_text } " )
137- end_time = time .time ()
138- execution_time = (end_time - start_time ) / 60
139- print (f"Local Translation Execution time: { execution_time :.2f} minutes" )
140-
141- if target_language == 'en' :
142- translated_text = transcription
143-
144- # Generate final audio output from translated text
145- self .generate_audio (translated_text , Translation_chunk_output_path , target_language , input_path )
146-
147- # Log success
148- logging .info (f"Translation successful for { input_path } . Translated text: { transcription } " )
149- return translated_text
141+ if target_language == 'en' :
142+ translated_text = transcription
143+
144+ if target_language != 'en' and translation_method == 'TowerInstruct-7B' :
145+ translated_text = self .validate_translation (transcription , target_language )
146+
147+ self .generate_audio (translated_text , Translation_chunk_output_path , target_language , input_path )
148+
149+ return translated_text
150+ self .unload_whisper_model ()
150151
151152 except Exception as e :
152- # Log errors
153153 logging .error (f"Error processing audio: { e } " )
154- raise # Re-raise the exception
154+ return "An Error occurred!" , None
155+
156+ def validate_translation (self , source_text , target_language ):
157+ print ('validate_translation started ..' )
158+ start_time = time .time ()
159+
160+ languages = {
161+ "English" : "en" ,
162+ "Spanish" : "es" ,
163+ "French" : "fr" ,
164+ "German" : "de" ,
165+ "Korean" : "ko" ,
166+ "Russian" : "ru" ,
167+ "Italian" : "it" ,
168+ "Portuguese" : "pt" ,
169+ "Chinese (Mandarin)" : "zh" ,
170+ "Dutch" : "nl"
171+ }
172+
173+ code_to_language = {code : lang for lang , code in languages .items ()}
174+ target_language = code_to_language .get (target_language , "Unknown language" )
175+
176+ #supports 10 languages: English, German, French, Spanish, Chinese, Portuguese, Italian, Russian, Korean, and Dutch
177+ pipe = pipeline ("text-generation" , model = "Unbabel/TowerInstruct-7B-v0.2" , torch_dtype = torch .bfloat16 , device_map = device )
178+ # We use the tokenizer’s chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
179+ messages = [
180+ {
181+ "role" : "user" ,
182+ "content" : (
183+ f"Translate the following text from English into { target_language } .\n "
184+ f"English: { source_text } \n "
185+ f"{ target_language } :"
186+ ),
187+ }
188+ ]
189+
190+ #print(target_language)
191+ prompt = pipe .tokenizer .apply_chat_template (messages , tokenize = False , add_generation_prompt = True )
192+ outputs = pipe (prompt , max_new_tokens = 256 , do_sample = False )
193+ generated_text = outputs [0 ]["generated_text" ]
194+
195+ #translated_text = generated_text.split("English:")[-1].strip()
196+
197+ # Further sanitize to remove undesired formatting tokens
198+ generated_text = (
199+ generated_text .replace ("<|im_start|>" , "" )
200+ .replace ("<|im_end|>" , "" )
201+ .strip ()
202+ )
203+
204+ # Define the unwanted substrings in a list
205+ unwanted_substrings = [
206+ target_language ,
207+ source_text ,
208+ 'assistant' ,
209+ 'Translate the following text from English into .' ,
210+ '\n ' ,
211+ 'English:' ,
212+ ':'
213+ ]
214+
215+ # Remove the unwanted substrings
216+ translated_text = generated_text .split ("\n " , 1 )[- 1 ].strip () # Split and strip the first line
217+ for substring in unwanted_substrings :
218+ translated_text = translated_text .replace (substring , '' )
219+
220+ print (f'validate_translation: { translated_text } ' )
221+ end_time = time .time ()
222+ execution_time = (end_time - start_time ) / 60
223+ print (f"Generate_audio Execution time: { execution_time :.2f} minutes" )
224+ return translated_text
155225
156226 def generate_audio (self , text , output_path , target_language , input_path ):
157227 print ("Generate audio" )
158-
159- # Text to speech to a file
160228 start_time = time .time ()
161- self .tts = TTS ("tts_models/multilingual/multi-dataset/xtts_v2" ).to (device )
229+
230+ self .load_tts_model ()
231+
162232 self .tts .tts_to_file (text = text , speaker_wav = input_path , language = target_language , file_path = output_path )
233+
163234 end_time = time .time ()
164235 execution_time = (end_time - start_time ) / 60
165236 print (f"Generate_audio Execution time: { execution_time :.2f} minutes" )
237+
238+ self .unload_tts_model ()
0 commit comments