Skip to content

Commit 406a3e2

Browse files
authored
Merge pull request #3 from overcrash66/develop
improve web ui
2 parents 10a4d34 + cb4b783 commit 406a3e2

File tree

4 files changed

+397
-152
lines changed

4 files changed

+397
-152
lines changed

OpenTranslator/translator.py

Lines changed: 189 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import os
1111
import unicodedata
1212

13+
from transformers import pipeline
14+
1315
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1416

1517
class CustomTranslator:
@@ -19,147 +21,218 @@ def __init__(self, output_dir="output"):
1921
self.translation_method = ""
2022
self.output_dir = output_dir
2123
os.makedirs(self.output_dir, exist_ok=True)
22-
# Initialize other attributes as needed
2324

24-
def load_models(self):
25+
def load_whisper_model(self):
2526
self.processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3")
2627
self.model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3").to(device)
27-
# self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
2828

29-
def process_audio_chunk(self, input_path, target_language, chunk_idx, output_path, translation_method):
30-
try:
31-
if translation_method == 'Local':
32-
self.load_models()
33-
start_time = time.time()
34-
# Load input audio file using librosa
35-
input_waveform, input_sampling_rate = librosa.load(input_path, sr=None, mono=True)
29+
def unload_whisper_model(self):
30+
del self.processor
31+
del self.model
3632

37-
# Convert NumPy array to PyTorch tensor if needed
38-
if not isinstance(input_waveform, torch.Tensor):
39-
input_waveform = torch.tensor(input_waveform)
33+
def load_mbart_model(self):
34+
self.mbart_model = MBartForConditionalGeneration.from_pretrained("SnypzZz/Llama2-13b-Language-translate").to(device)
35+
self.mbart_tokenizer = MBart50TokenizerFast.from_pretrained("SnypzZz/Llama2-13b-Language-translate", src_lang="en_XX", device=device)
4036

41-
forced_decoder_ids = self.processor.get_decoder_prompt_ids(language=target_language, task="translate")
37+
def unload_mbart_model(self):
38+
del self.mbart_model
39+
del self.mbart_tokenizer
40+
41+
def load_tts_model(self):
42+
self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
4243

43-
# Ensure the input audio has a proper frame rate
44-
if input_sampling_rate != 16000:
45-
resampler = torchaudio.transforms.Resample(orig_freq=input_sampling_rate, new_freq=16000)
46-
input_waveform = resampler(input_waveform)
44+
def unload_tts_model(self):
45+
del self.tts
4746

48-
# Process the input audio with the processor
49-
input_features = self.processor(input_waveform.numpy(), sampling_rate=16000, return_tensors="pt")
47+
def process_audio_chunk(self, input_path, target_language, chunk_idx, output_path, translation_method , batch_size=4):
48+
try:
49+
start_time = time.time()
50+
51+
self.load_whisper_model()
52+
53+
# Load audio waveform
54+
input_waveform, input_sampling_rate = librosa.load(input_path, sr=None, mono=True)
55+
56+
if not isinstance(input_waveform, torch.Tensor):
57+
input_waveform = torch.tensor(input_waveform)
58+
59+
if input_sampling_rate != 16000:
60+
resampler = torchaudio.transforms.Resample(orig_freq=input_sampling_rate, new_freq=16000)
61+
input_waveform = resampler(torch.tensor(input_waveform).clone().detach()).numpy()
62+
63+
# Prepare forced decoder IDs
64+
forced_decoder_ids = self.processor.get_decoder_prompt_ids(language=target_language, task="translate")
65+
66+
# Create batches of input features
67+
input_features = self.processor(
68+
input_waveform,
69+
sampling_rate=16000,
70+
return_tensors="pt",
71+
padding=True
72+
)
73+
input_features = {k: v.to(device) for k, v in input_features.items()}
74+
input_batches = torch.split(input_features["input_features"], batch_size, dim=0)
75+
76+
# Process batches
77+
transcriptions = []
78+
for batch in input_batches:
79+
with torch.no_grad():
80+
predicted_ids = self.model.generate(batch, forced_decoder_ids=forced_decoder_ids, max_length=448)
81+
transcriptions.extend(self.processor.batch_decode(predicted_ids, skip_special_tokens=True))
82+
83+
# Combine transcriptions
84+
transcription = " ".join(transcriptions)
85+
86+
del input_waveform, input_sampling_rate
87+
88+
end_time = time.time()
89+
execution_time = (end_time - start_time) / 60
90+
print(f"Transcription Execution time: {execution_time:.2f} minutes")
91+
92+
words = transcription.split()
93+
cleaned_words = [words[0]]
94+
for word in words[1:]:
95+
if word != cleaned_words[-1]:
96+
cleaned_words.append(word)
97+
cleaned_str = ' '.join(cleaned_words)
98+
99+
sentences = cleaned_str.split('.')
100+
cleaned_sentences = [sentences[0]]
101+
for sentence in sentences[1:]:
102+
if sentence != cleaned_sentences[-1]:
103+
cleaned_sentences.append(sentence)
104+
cleaned_transcription = '.'.join(cleaned_sentences)
105+
106+
transcription = cleaned_transcription
107+
print('Speech recognition and translate to English text: ' + str(transcription))
108+
109+
Translation_chunk_output_path = os.path.join(self.output_dir, f"{os.path.splitext(os.path.basename(output_path))[0]}_Translation_chunk{chunk_idx + 1}.wav")
110+
111+
if target_language != 'en' and translation_method == 'Llama2-13b':
112+
print("Local text translation started..")
113+
start_time = time.time()
114+
self.load_mbart_model()
50115

51-
# Move input features to the device used by the model
52-
input_features = {k: v.to(device) for k, v in input_features.items()}
116+
inputs = self.mbart_tokenizer(transcription, return_tensors="pt")
117+
input_ids = inputs["input_ids"].to(device)
53118

54-
# Generate token ids
55-
predicted_ids = self.model.generate(input_features["input_features"], forced_decoder_ids=forced_decoder_ids)
119+
language_mapping = {
120+
"en": "en_XX", "es": "es_XX", "fr": "fr_XX", "de": "de_DE",
121+
"ja": "ja_XX", "ko": "ko_KR", "tr": "tr_TR", "ar": "ar_AR",
122+
"ru": "ru_RU", "he": "he_IL", "hi": "hi_IN", "it": "it_IT",
123+
"pt": "pt_XX", "zh": "zh_CN", "cs": "cs_CZ", "nl": "nl_XX", "pl": "pl_PL",
124+
}
125+
model_target_language = language_mapping.get(target_language, "en_XX")
56126

57-
# Decode token ids to text
58-
transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
127+
# Generate tokens on the GPU
128+
generated_tokens = self.mbart_model.generate(input_ids=input_ids, forced_bos_token_id=self.mbart_tokenizer.lang_code_to_id[model_target_language])
59129

60-
del input_waveform, input_sampling_rate
130+
# Decode and join the translated text
131+
translated_text = self.mbart_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
132+
translated_text = ", ".join(translated_text)
61133

134+
self.unload_mbart_model()
135+
136+
print('Mbart Translation: '+ str(translated_text))
62137
end_time = time.time()
63138
execution_time = (end_time - start_time) / 60
64139
print(f"Transcription Execution time: {execution_time:.2f} minutes")
65140

66-
# Fix a bug: Text validation check if we have duplicate successive words
67-
words = transcription.split()
68-
cleaned_words = [words[0]]
69-
70-
for word in words[1:]:
71-
if word != cleaned_words[-1]:
72-
cleaned_words.append(word)
73-
74-
cleaned_str = ' '.join(cleaned_words)
75-
76-
transcription = cleaned_str
77-
78-
# Fix duplicate successive sentences
79-
sentences = transcription.split('.')
80-
cleaned_sentences = [sentences[0]]
81-
82-
for sentence in sentences[1:]:
83-
if sentence != cleaned_sentences[-1]:
84-
cleaned_sentences.append(sentence)
85-
86-
cleaned_transcription = '.'.join(cleaned_sentences)
87-
88-
transcription = cleaned_transcription
89-
print('Speech recognition and translate to English text: ' + str(transcription))
90-
91-
Translation_chunk_output_path = os.path.join(self.output_dir, f"{os.path.splitext(os.path.basename(output_path))[0]}_Translation_chunk{chunk_idx + 1}.wav")
92-
93-
# If target language is English, skip text translation
94-
if target_language != 'en':
95-
# Local text translation
96-
print("Local text translation started..")
97-
start_time = time.time()
98-
tt = MBartForConditionalGeneration.from_pretrained("SnypzZz/Llama2-13b-Language-translate").to(device)
99-
tokenizer = MBart50TokenizerFast.from_pretrained("SnypzZz/Llama2-13b-Language-translate", src_lang="en_XX", device=device)
100-
101-
# Tokenize and convert to PyTorch tensor
102-
inputs = tokenizer(transcription, return_tensors="pt")
103-
input_ids = inputs["input_ids"].to(device)
104-
105-
# Map target languages to model language codes
106-
language_mapping = {
107-
"en": "en_XX",
108-
"es": "es_XX",
109-
"fr": "fr_XX",
110-
"de": "de_DE",
111-
"ja": "ja_XX",
112-
"ko": "ko_KR",
113-
"tr": "tr_TR",
114-
"ar": "ar_AR",
115-
"ru": "ru_RU",
116-
"he": "he_IL",
117-
"hi": "hi_IN",
118-
"it": "it_IT",
119-
"pt": "pt_XX",
120-
"zh": "zh_CN",
121-
"cs": "cs_CZ",
122-
"nl": "nl_XX",
123-
"pl": "pl_PL",
124-
}
125-
126-
# Set the target language based on the mapping
127-
model_target_language = language_mapping.get(target_language, "en_XX")
128-
129-
# Generate tokens on the GPU
130-
generated_tokens = tt.generate(input_ids=input_ids, forced_bos_token_id=tokenizer.lang_code_to_id[model_target_language])
131-
132-
# Decode and join the translated text
133-
translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
134-
translated_text = ", ".join(translated_text)
135-
136-
logging.info(f"Processing successful. Translated text: {translated_text}")
137-
end_time = time.time()
138-
execution_time = (end_time - start_time) / 60
139-
print(f"Local Translation Execution time: {execution_time:.2f} minutes")
140-
141-
if target_language == 'en':
142-
translated_text = transcription
143-
144-
# Generate final audio output from translated text
145-
self.generate_audio(translated_text, Translation_chunk_output_path, target_language, input_path)
146-
147-
# Log success
148-
logging.info(f"Translation successful for {input_path}. Translated text: {transcription}")
149-
return translated_text
141+
if target_language == 'en':
142+
translated_text = transcription
143+
144+
if target_language != 'en' and translation_method == 'TowerInstruct-7B':
145+
translated_text = self.validate_translation(transcription, target_language)
146+
147+
self.generate_audio(translated_text, Translation_chunk_output_path, target_language, input_path)
148+
149+
return translated_text
150+
self.unload_whisper_model()
150151

151152
except Exception as e:
152-
# Log errors
153153
logging.error(f"Error processing audio: {e}")
154-
raise # Re-raise the exception
154+
return "An Error occurred!", None
155+
156+
def validate_translation(self, source_text, target_language):
157+
print('validate_translation started ..')
158+
start_time = time.time()
159+
160+
languages = {
161+
"English": "en",
162+
"Spanish": "es",
163+
"French": "fr",
164+
"German": "de",
165+
"Korean": "ko",
166+
"Russian": "ru",
167+
"Italian": "it",
168+
"Portuguese": "pt",
169+
"Chinese (Mandarin)": "zh",
170+
"Dutch": "nl"
171+
}
172+
173+
code_to_language = {code: lang for lang, code in languages.items()}
174+
target_language = code_to_language.get(target_language, "Unknown language")
175+
176+
#supports 10 languages: English, German, French, Spanish, Chinese, Portuguese, Italian, Russian, Korean, and Dutch
177+
pipe = pipeline("text-generation", model="Unbabel/TowerInstruct-7B-v0.2", torch_dtype=torch.bfloat16, device_map=device)
178+
# We use the tokenizer’s chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating
179+
messages = [
180+
{
181+
"role": "user",
182+
"content": (
183+
f"Translate the following text from English into {target_language}.\n"
184+
f"English: {source_text}\n"
185+
f"{target_language}:"
186+
),
187+
}
188+
]
189+
190+
#print(target_language)
191+
prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
192+
outputs = pipe(prompt, max_new_tokens=256, do_sample=False)
193+
generated_text = outputs[0]["generated_text"]
194+
195+
#translated_text = generated_text.split("English:")[-1].strip()
196+
197+
# Further sanitize to remove undesired formatting tokens
198+
generated_text = (
199+
generated_text.replace("<|im_start|>", "")
200+
.replace("<|im_end|>", "")
201+
.strip()
202+
)
203+
204+
# Define the unwanted substrings in a list
205+
unwanted_substrings = [
206+
target_language,
207+
source_text,
208+
'assistant',
209+
'Translate the following text from English into .',
210+
'\n',
211+
'English:',
212+
':'
213+
]
214+
215+
# Remove the unwanted substrings
216+
translated_text = generated_text.split("\n", 1)[-1].strip() # Split and strip the first line
217+
for substring in unwanted_substrings:
218+
translated_text = translated_text.replace(substring, '')
219+
220+
print(f'validate_translation: {translated_text}')
221+
end_time = time.time()
222+
execution_time = (end_time - start_time) / 60
223+
print(f"Generate_audio Execution time: {execution_time:.2f} minutes")
224+
return translated_text
155225

156226
def generate_audio(self, text, output_path, target_language, input_path):
157227
print("Generate audio")
158-
159-
# Text to speech to a file
160228
start_time = time.time()
161-
self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
229+
230+
self.load_tts_model()
231+
162232
self.tts.tts_to_file(text=text, speaker_wav=input_path, language=target_language, file_path=output_path)
233+
163234
end_time = time.time()
164235
execution_time = (end_time - start_time) / 60
165236
print(f"Generate_audio Execution time: {execution_time:.2f} minutes")
237+
238+
self.unload_tts_model()

Screenshot.png

37.6 KB
Loading

0 commit comments

Comments
 (0)