Skip to content

Commit 4edfb6e

Browse files
committed
add the option to use the FB M2M100 model for translation
1 parent 1928f77 commit 4edfb6e

File tree

4 files changed

+51
-7
lines changed

4 files changed

+51
-7
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -131,9 +131,10 @@ $ subaligner --languages
131131
$ subaligner -m single -v video.mp4 -s subtitle.srt -t src,tgt
132132
$ subaligner -m dual -v video.mp4 -s subtitle.srt -t src,tgt
133133
$ subaligner -m script -v test.mp4 -s subtitle.txt -o subtitle_aligned.srt -t src,tgt
134-
$ subaligner -m dual -v video.mp4 -tr helsinki-nlp -o subtitle_aligned.srt -t src,tgt
135-
$ subaligner -m dual -v video.mp4 -tr facebook-mbart -tf large -o subtitle_aligned.srt -t src,tgt
136-
$ subaligner -m dual -v video.mp4 -tr whisper -tf small -o subtitle_aligned.srt -t src,eng
134+
$ subaligner -m dual -v video.mp4 -s subtitle.srt -tr helsinki-nlp -o subtitle_aligned.srt -t src,tgt
135+
$ subaligner -m dual -v video.mp4 -s subtitle.srt -tr facebook-mbart -tf large -o subtitle_aligned.srt -t src,tgt
136+
$ subaligner -m dual -v video.mp4 -s subtitle.srt -tr facebook-m2m100 -tf small -o subtitle_aligned.srt -t src,tgt
137+
$ subaligner -m dual -v video.mp4 -s subtitle.srt -tr whisper -tf small -o subtitle_aligned.srt -t src,eng
137138
```
138139
```
139140
# Transcribe audiovisual files and generate translated subtitles

site/source/usage.rst

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -51,9 +51,10 @@ Make sure you have got the virtual environment activated upfront.
5151
(.venv) $ subaligner -m single -v video.mp4 -s subtitle.srt -t src,tgt
5252
(.venv) $ subaligner -m dual -v video.mp4 -s subtitle.srt -t src,tgt
5353
(.venv) $ subaligner -m script -v test.mp4 -s subtitle.txt -o subtitle_aligned.srt -t src,tgt
54-
(.venv) $ subaligner -m dual -v video.mp4 -tr helsinki-nlp -o subtitle_aligned.srt -t src,tgt
55-
(.venv) $ subaligner -m dual -v video.mp4 -tr facebook-mbart -tf large -o subtitle_aligned.srt -t src,tgt
56-
(.venv) $ subaligner -m dual -v video.mp4 -tr whisper -tf small -o subtitle_aligned.srt -t src,eng
54+
(.venv) $ subaligner -m dual -v video.mp4 -s subtitle.srt -tr helsinki-nlp -o subtitle_aligned.srt -t src,tgt
55+
(.venv) $ subaligner -m dual -v video.mp4 -s subtitle.srt -tr facebook-mbart -tf large -o subtitle_aligned.srt -t src,tgt
56+
(.venv) $ subaligner -m dual -v video.mp4 -s subtitle.srt -tr facebook-m2m100 -tf small -o subtitle_aligned.srt -t src,tgt
57+
(.venv) $ subaligner -m dual -v video.mp4 -s subtitle.srt -tr whisper -tf small -o subtitle_aligned.srt -t src,eng
5758

5859
**Transcribe audiovisual files and generate translated subtitles**::
5960

subaligner/llm.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@ class TranslationRecipe(Enum):
99
HELSINKI_NLP = "helsinki-nlp"
1010
WHISPER = "whisper"
1111
FACEBOOK_MBART = "facebook-mbart"
12+
FACEBOOK_M2M100 = "facebook-m2m100"
1213

1314

1415
class WhisperFlavour(Enum):
@@ -34,3 +35,7 @@ class HelsinkiNLPFlavour(Enum):
3435

3536
class FacebookMbartFlavour(Enum):
3637
LARGE = "large"
38+
39+
40+
class FacebookM2m100Flavour(Enum):
41+
SMALL = "small"

subaligner/translator.py

Lines changed: 38 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,12 @@
1212
MarianTokenizer,
1313
MBart50TokenizerFast,
1414
MBartForConditionalGeneration,
15+
M2M100ForConditionalGeneration,
16+
M2M100Tokenizer,
1517
)
1618
from whisper.tokenizer import LANGUAGES
1719
from .singleton import Singleton
18-
from .llm import TranslationRecipe, HelsinkiNLPFlavour, WhisperFlavour, FacebookMbartFlavour
20+
from .llm import TranslationRecipe, HelsinkiNLPFlavour, WhisperFlavour, FacebookMbartFlavour, FacebookM2m100Flavour
1921
from .utils import Utils
2022
from .subtitle import Subtitle
2123
from .logger import Logger
@@ -147,6 +149,27 @@ def translate(self,
147149
new_subs[index].text = translated_texts[index]
148150
self.__LOGGER.info("Subtitle translated")
149151
return new_subs
152+
elif self.__recipe == TranslationRecipe.FACEBOOK_M2M100.value:
153+
src_lang, tgt_lang = language_pair if language_pair is not None else (self.__src_language, self.__tgt_language)
154+
self.__tokenizer.src_lang = Utils.get_iso_639_alpha_2(src_lang)
155+
lang_code = Utils.get_iso_639_alpha_2(tgt_lang)
156+
if src_lang is None or tgt_lang is None:
157+
raise NotImplementedError(
158+
f"Language pair of {src_lang} and {src_lang} is not supported by {self.__recipe}")
159+
translated_texts = []
160+
self.__lang_model.eval()
161+
new_subs = deepcopy(subs)
162+
src_texts = [sub.text for sub in new_subs]
163+
num_of_batches = math.ceil(len(src_texts) / Translator.__TRANSLATING_BATCH_SIZE)
164+
self.__LOGGER.info("Translating %s subtitle cue(s)..." % len(src_texts))
165+
for batch in tqdm(Translator.__batch(src_texts, Translator.__TRANSLATING_BATCH_SIZE), total=num_of_batches):
166+
input_ids = self.__tokenizer(batch, return_tensors=Translator.__TENSOR_TYPE, padding=True)
167+
translated = self.__lang_model.generate(**input_ids, forced_bos_token_id=self.__tokenizer.get_lang_id(lang_code))
168+
translated_texts.extend([self.__tokenizer.decode(t, skip_special_tokens=True) for t in translated])
169+
for index in range(len(new_subs)):
170+
new_subs[index].text = translated_texts[index]
171+
self.__LOGGER.info("Subtitle translated")
172+
return new_subs
150173
else:
151174
return []
152175

@@ -178,6 +201,13 @@ def __initialise_model(self, src_lang: str, tgt_lang: str, recipe: str, flavour:
178201
self.__download_mbart_model(flavour)
179202
else:
180203
raise NotImplementedError(f"Unknown {recipe} flavour: {flavour}")
204+
elif recipe == TranslationRecipe.FACEBOOK_M2M100.value:
205+
if flavour in [f.value for f in FacebookM2m100Flavour]:
206+
self.__download_m2m100_model(flavour)
207+
else:
208+
raise NotImplementedError(f"Unknown {recipe} flavour: {flavour}")
209+
else:
210+
raise NotImplementedError(f"Unknown recipe: {recipe}")
181211

182212
def __download_mt_model(self, src_lang: str, tgt_lang: str, flavour: str) -> bool:
183213
try:
@@ -216,6 +246,13 @@ def __download_mbart_model(self, flavour: str) -> None:
216246
self.__lang_model = MBartForConditionalGeneration.from_pretrained(mbart_model_name)
217247
self.__LOGGER.debug("mBART model %s downloaded" % mbart_model_name)
218248

249+
def __download_m2m100_model(self, flavour: str) -> None:
    """Load the M2M100 tokenizer and model and store them on the instance.

    Args:
        flavour: The requested model flavour. Only one checkpoint is wired
            up, so every flavour resolves to ``facebook/m2m100_418M``.
    """
    # The original conditional chose between two identical checkpoint names,
    # so the test on ``flavour`` had no effect; a plain assignment states the
    # actual behaviour. Add a flavour-to-checkpoint mapping here if further
    # M2M100 sizes are supported later.
    m2m100_model_name = "facebook/m2m100_418M"
    self.__LOGGER.debug("Trying to download the M2M100 model %s" % m2m100_model_name)
    self.__tokenizer = M2M100Tokenizer.from_pretrained(m2m100_model_name)
    self.__lang_model = M2M100ForConditionalGeneration.from_pretrained(m2m100_model_name)
    self.__LOGGER.debug("M2M100 model %s downloaded" % m2m100_model_name)
255+
219256
def __download_by_mt_name(self, mt_model_name: str) -> None:
220257
self.__LOGGER.debug("Trying to download the MT model %s" % mt_model_name)
221258
self.__tokenizer = MarianTokenizer.from_pretrained(mt_model_name)

0 commit comments

Comments
 (0)