Skip to content

Commit 52b95a6

Browse files
authored
Merge pull request #4 from MaithriRao/nllb
Performing translation of text to gloss and gloss to text using a pretrained NLLB model
2 parents 74e6ef3 + 5867652 commit 52b95a6

File tree

3 files changed

+447
-0
lines changed

3 files changed

+447
-0
lines changed

nllb/data_selection.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import math
2+
import torch
3+
import torch.nn as nn
4+
from collections import Counter
5+
from torch import Tensor
6+
import io
7+
import time
8+
import os
9+
import pandas as pd
10+
import json
11+
from datetime import datetime
12+
13+
# Names of the columns that read() copies out of every .mms table, in the order
# they are stored on each record. The per-hand columns share one suffix list and
# are generated for the "dom" and "ndom" prefixes (presumably dominant /
# non-dominant hand -- confirm against the MMS format description).
_HAND_FEATURE_SUFFIXES = (
    "relocx", "relocy", "relocz",
    "relocax", "relocay", "relocaz",
    "relocsx", "relocsy", "relocsz",
    "rotx", "roty", "rotz",
)
features_names = [
    "maingloss",
    "domgloss",
    "ndomgloss",
    "domreloc",
    "ndomreloc",
] + [
    f"{hand}hand{suffix}"
    for hand in ("dom", "ndom")
    for suffix in _HAND_FEATURE_SUFFIXES
]
20+
21+
def read(text_info, mms_info):
    """Load paired text annotations and .mms tables from one directory pair.

    Every entry of the text directory is treated as a sample id. For each id
    the table ``<mms_directory>/<id>.mms`` is read with pandas and the file
    ``<text_directory>/<id>/gebaerdler.Text_Deutsch.annotation~`` supplies the
    text: the third ``;``-separated field of every line, joined with single
    spaces.

    Args:
        text_info: ``(text_directory, text_encoding)`` pair.
        mms_info: ``(mms_directory, mms_encoding)`` pair.

    Returns:
        A list with one dict per sample, holding ``"file_ID"``, ``"text"`` and
        one list per name in ``features_names``. Missing values in
        ``domgloss`` / ``ndomgloss`` are replaced by the ``maingloss`` value of
        the same row.

    Raises:
        FileNotFoundError: If the annotation file of a sample is missing (a
            missing .mms file is only reported and the sample skipped).
        IndexError: If an annotation line has fewer than three ``;`` fields.
    """
    data_list = []
    (text_directory, text_encoding) = text_info
    print("text_directory: ", text_directory)
    (mms_directory, mms_encoding) = mms_info
    for filenumber in os.listdir(text_directory):
        mms_path = os.path.join(mms_directory, filenumber + ".mms")
        try:
            df = pd.read_csv(mms_path, encoding=mms_encoding)
        except FileNotFoundError as e:
            print(f"WARNING: Text file exists while mms file does not, skipping: {e}")
            continue

        text_address = os.path.join(text_directory, filenumber, "gebaerdler.Text_Deutsch.annotation~")
        # Context manager so the handle is released for every sample instead of
        # being left open until garbage collection.
        with open(text_address, encoding=text_encoding) as annotation_file:
            text_line = " ".join(
                line.replace("\n", "").split(";")[2] for line in annotation_file
            )

        data_dict = {"file_ID": filenumber, "text": text_line}
        # Read the fallback column straight from the table so the substitution
        # below does not depend on the order of features_names.
        maingloss = df["maingloss"].tolist()
        for feature in features_names:
            if feature in ("domgloss", "ndomgloss"):
                data_dict[feature] = [
                    main if pd.isnull(token) else token
                    for main, token in zip(maingloss, df[feature])
                ]
            else:
                data_dict[feature] = df[feature].tolist()
        data_list.append(data_dict)
    return data_list

nllb/datasets.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import math
2+
import torch
3+
from torch import Tensor
4+
import io
5+
import time
6+
import os
7+
import pandas as pd
8+
import json
9+
from datetime import datetime
10+
import pickle
11+
from pathlib import Path
12+
from torch.utils.data import Dataset
13+
from collections import Counter
14+
from torch.nn.utils.rnn import pad_sequence
15+
import torchtext
16+
from torchtext.data.utils import get_tokenizer
17+
from collections import Counter
18+
from torchtext.vocab import vocab
19+
import numpy as np
20+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
21+
import torch.nn.functional as F
22+
from pathlib import Path
23+
from . import data_selection
24+
25+
# Parallel lists of (directory, file encoding) pairs: entry i of
# text_directories is read together with entry i of mms_directories. The first
# entry of each list is the latin-1 original corpus; the remaining entries are
# the utf-8 "modified" subsets, which share one directory layout.
_MODIFIED_SUBSETS = ("location", "platform", "time", "train_name")

mms_directories = [("mms-subset91", 'latin-1')] + [
    (f"modified/{subset}/mms", 'utf-8') for subset in _MODIFIED_SUBSETS
]
text_directories = [("annotations_full/annotations", 'latin-1')] + [
    (f"modified/{subset}/text", 'utf-8') for subset in _MODIFIED_SUBSETS
]
39+
40+
# Identifier of the pretrained NLLB checkpoint; it is passed to
# AutoTokenizer.from_pretrained to load the matching tokenizer.
checkpoint = 'facebook/nllb-200-distilled-600M' #for nllb
41+
# Module-level tokenizer, created when this module is imported. collate_fn
# reads its pad_token_id to pad batches.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
42+
43+
def read():
    """Read every configured corpus and split it into original and modified data.

    Walks ``text_directories`` and, for each position, hands that entry and the
    entry at the same position of ``mms_directories`` to
    ``data_selection.read``. Records from position 0 are collected as the
    original data; records from all later positions as the modified data.

    Returns:
        A tuple ``(original, modified, full)`` of lists, where ``full`` is the
        original records followed by the modified records.
    """
    original_records = []
    modified_records = []
    for position, text_info in enumerate(text_directories):
        records = data_selection.read(text_info, mms_directories[position])
        # enumerate() starts at 0, so only the first directory pair is "original".
        bucket = modified_records if position > 0 else original_records
        bucket += records

    return (original_records, modified_records, original_records + modified_records)
57+
58+
59+
class SignLanguageDataset(Dataset):
    """Dataset of tokenized (text, main-gloss) pairs.

    Each element of ``data_list`` is a dict providing the keys ``'file_ID'``,
    ``'text'`` (a string) and ``'maingloss'`` (a sequence of strings, which is
    joined with spaces and lower-cased before tokenization).
    """

    def __init__(self, data_list, tokenizer, max_length=512):
        """Store the records and the tokenizer used to encode them.

        Args:
            data_list: Sequence of record dicts (see the class docstring).
            tokenizer: Tokenizer exposing ``encode(text, add_special_tokens=...,
                truncation=..., max_length=...)`` and ``__len__``.
            max_length: Maximum number of token ids per encoded sequence.
                ``None`` disables truncation.
        """
        self.data_list = data_list
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.vocab_size = len(tokenizer)

    def __len__(self):
        """Return the number of records."""
        return len(self.data_list)

    def _encode(self, text):
        """Encode ``text`` into a 1-D tensor of token ids, honouring ``max_length``."""
        token_ids = self.tokenizer.encode(
            text,
            add_special_tokens=True,
            # max_length used to be stored but never applied, so over-long
            # inputs went through untruncated.
            truncation=self.max_length is not None,
            max_length=self.max_length,
        )
        return torch.tensor(token_ids)

    def __getitem__(self, idx):
        """Return ``(file_id, text_token_ids, maingloss_token_ids)`` for record ``idx``."""
        data = self.data_list[idx]
        file_Id = data['file_ID']
        text_tokens = self._encode(data['text'])
        maingloss_tokens = self._encode(' '.join(data['maingloss']).lower())

        return file_Id, text_tokens, maingloss_tokens
79+
80+
81+
def collate_fn(batch):
    """Collate dataset items into two equally wide, padded token batches.

    Args:
        batch: Iterable of ``(file_id, text_tokens, maingloss_tokens)`` items,
            where both token entries are 1-D tensors.

    Returns:
        ``(file_ids, text_batch, maingloss_batch)``: the tuple of file ids and
        two batch-first tensors padded with the tokenizer's pad token id to
        the same sequence length.
    """
    file_Id, text_seqs, gloss_seqs = zip(*batch)
    # Pad id of the module-level tokenizer (noted as 1 for NLLB in the original code).
    pad_id = tokenizer.pad_token_id

    def _stack(sequences):
        # Right-pad every sequence to the longest one of its own group.
        return torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=pad_id)

    text_batch = _stack(text_seqs)
    gloss_batch = _stack(gloss_seqs)

    # Bring both groups to one common width so they share a sequence length.
    common_len = max(text_batch.size(1), gloss_batch.size(1))

    def _widen(padded):
        missing = common_len - padded.size(1)
        return torch.nn.functional.pad(padded, (0, missing), value=pad_id)

    text_batch = _widen(text_batch)
    gloss_batch = _widen(gloss_batch)

    return file_Id, text_batch, gloss_batch

0 commit comments

Comments
 (0)