|
12 | 12 | MarianTokenizer, |
13 | 13 | MBart50TokenizerFast, |
14 | 14 | MBartForConditionalGeneration, |
| 15 | + M2M100ForConditionalGeneration, |
| 16 | + M2M100Tokenizer, |
15 | 17 | ) |
16 | 18 | from whisper.tokenizer import LANGUAGES |
17 | 19 | from .singleton import Singleton |
18 | | -from .llm import TranslationRecipe, HelsinkiNLPFlavour, WhisperFlavour, FacebookMbartFlavour |
| 20 | +from .llm import TranslationRecipe, HelsinkiNLPFlavour, WhisperFlavour, FacebookMbartFlavour, FacebookM2m100Flavour |
19 | 21 | from .utils import Utils |
20 | 22 | from .subtitle import Subtitle |
21 | 23 | from .logger import Logger |
@@ -147,6 +149,27 @@ def translate(self, |
147 | 149 | new_subs[index].text = translated_texts[index] |
148 | 150 | self.__LOGGER.info("Subtitle translated") |
149 | 151 | return new_subs |
| 152 | + elif self.__recipe == TranslationRecipe.FACEBOOK_M2M100.value: |
| 153 | + src_lang, tgt_lang = language_pair if language_pair is not None else (self.__src_language, self.__tgt_language) |
| 154 | + if src_lang is None or tgt_lang is None: |
| 155 | + raise NotImplementedError( |
| 156 | + f"Language pair of {src_lang} and {tgt_lang} is not supported by {self.__recipe}") |
| 157 | + self.__tokenizer.src_lang = Utils.get_iso_639_alpha_2(src_lang) |
| 158 | + lang_code = Utils.get_iso_639_alpha_2(tgt_lang) |
| 159 | + translated_texts = [] |
| 160 | + self.__lang_model.eval() |
| 161 | + new_subs = deepcopy(subs) |
| 162 | + src_texts = [sub.text for sub in new_subs] |
| 163 | + num_of_batches = math.ceil(len(src_texts) / Translator.__TRANSLATING_BATCH_SIZE) |
| 164 | + self.__LOGGER.info("Translating %s subtitle cue(s)..." % len(src_texts)) |
| 165 | + for batch in tqdm(Translator.__batch(src_texts, Translator.__TRANSLATING_BATCH_SIZE), total=num_of_batches): |
| 166 | + input_ids = self.__tokenizer(batch, return_tensors=Translator.__TENSOR_TYPE, padding=True) |
| 167 | + translated = self.__lang_model.generate(**input_ids, forced_bos_token_id=self.__tokenizer.get_lang_id(lang_code)) |
| 168 | + translated_texts.extend([self.__tokenizer.decode(t, skip_special_tokens=True) for t in translated]) |
| 169 | + for index in range(len(new_subs)): |
| 170 | + new_subs[index].text = translated_texts[index] |
| 171 | + self.__LOGGER.info("Subtitle translated") |
| 172 | + return new_subs |
150 | 173 | else: |
151 | 174 | return [] |
152 | 175 |
|
@@ -178,6 +201,13 @@ def __initialise_model(self, src_lang: str, tgt_lang: str, recipe: str, flavour: |
178 | 201 | self.__download_mbart_model(flavour) |
179 | 202 | else: |
180 | 203 | raise NotImplementedError(f"Unknown {recipe} flavour: {flavour}") |
| 204 | + elif recipe == TranslationRecipe.FACEBOOK_M2M100.value: |
| 205 | + if flavour in [f.value for f in FacebookM2m100Flavour]: |
| 206 | + self.__download_m2m100_model(flavour) |
| 207 | + else: |
| 208 | + raise NotImplementedError(f"Unknown {recipe} flavour: {flavour}") |
| 209 | + else: |
| 210 | + raise NotImplementedError(f"Unknown recipe: {recipe}") |
181 | 211 |
|
182 | 212 | def __download_mt_model(self, src_lang: str, tgt_lang: str, flavour: str) -> bool: |
183 | 213 | try: |
@@ -216,6 +246,13 @@ def __download_mbart_model(self, flavour: str) -> None: |
216 | 246 | self.__lang_model = MBartForConditionalGeneration.from_pretrained(mbart_model_name) |
217 | 247 | self.__LOGGER.debug("mBART model %s downloaded" % mbart_model_name) |
218 | 248 |
|
| 249 | + def __download_m2m100_model(self, flavour: str) -> None: |
| 250 | + m2m100_model_name = "facebook/m2m100_418M" if flavour == "small" else "facebook/m2m100_1.2B" |
| 251 | + self.__LOGGER.debug("Trying to download the M2M100 model %s" % m2m100_model_name) |
| 252 | + self.__tokenizer = M2M100Tokenizer.from_pretrained(m2m100_model_name) |
| 253 | + self.__lang_model = M2M100ForConditionalGeneration.from_pretrained(m2m100_model_name) |
| 254 | + self.__LOGGER.debug("M2M100 model %s downloaded" % m2m100_model_name) |
| 255 | + |
219 | 256 | def __download_by_mt_name(self, mt_model_name: str) -> None: |
220 | 257 | self.__LOGGER.debug("Trying to download the MT model %s" % mt_model_name) |
221 | 258 | self.__tokenizer = MarianTokenizer.from_pretrained(mt_model_name) |
|
0 commit comments