
Commit 74e6ef3

Merge pull request #3 from MaithriRao/pretrained-model
Using pretrained models trained on an English corpus to perform translation
2 parents be064aa + 20a6d5d commit 74e6ef3

File tree

4 files changed: +713 −0 lines


gpt_english/dataselection.py

Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
import math
import torch
from torch import Tensor
import io
import time
import os
import pandas as pd
import json
from datetime import datetime
import pickle
from pathlib import Path
from torch.utils.data import Dataset
# from transformers import BertTokenizer
from collections import Counter
from itertools import chain
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import numpy as np
from transformers import GPT2Tokenizer
import torch.nn.functional as F

# Columns expected in each per-file CSV: the first seven are categorical or
# metadata fields, the remainder are continuous hand-inflection values
# (relocations, axes, scales, rotations) for the dominant and non-dominant hand.
features_names = ["maingloss", "domgloss", "ndomgloss", "domreloc", "ndomreloc", "framestart", "headmov",
                  "domhandrelocx", "domhandrelocy", "domhandrelocz", "domhandrelocax",
                  "domhandrelocay", "domhandrelocaz", "domhandrelocsx", "domhandrelocsy", "domhandrelocsz",
                  "domhandrotx", "domhandroty", "domhandrotz",
                  "ndomhandrelocx", "ndomhandrelocy", "ndomhandrelocz", "ndomhandrelocax",
                  "ndomhandrelocay", "ndomhandrelocaz", "ndomhandrelocsx", "ndomhandrelocsy", "ndomhandrelocsz",
                  "ndomhandrotx", "ndomhandroty", "ndomhandrotz"]

# Integer codes for categorical annotation values (used for head-movement and
# relocation features alike).
HEADMOV_TO_INT = {'0': 0, '1': 1, '2': 2, 'no': 3, 'yes': 4}

directory = "mms-subset91"
text_directory = "annotations_full/annotations"

def load_data(directory, text_directory, features_names):
    data_list = []
    for filename in os.listdir(directory):
        f = os.path.join(directory, filename)
        df = pd.read_csv(f)

        filenumber = filename.split(".")[0]

        # The German sentence lives in a matching annotation file; each line is
        # ';'-separated and the text is its third field.
        text_address = os.path.join(text_directory, filenumber, "gebaerdler.Text_Deutsch.annotation~")
        with open(text_address, encoding='latin-1') as text_file:
            lines = text_file.readlines()
        text_line = ""
        for i, text_data in enumerate(lines):
            if i > 0:
                text_line = text_line + " " + text_data.replace("\n", "").split(";")[2]
            else:
                text_line = text_line + text_data.replace("\n", "").split(";")[2]

        data_dict = {"file_ID": filenumber, "text": text_line}
        for feature in features_names:
            if feature == "domgloss" or feature == "ndomgloss":
                # Fall back to the main gloss wherever the dominant/non-dominant
                # gloss is missing ("maingloss" comes first in features_names,
                # so it is already filled in by this point).
                temp = df[feature].copy()
                data_dict[feature] = [data_dict["maingloss"][i] if pd.isnull(token) else token for i, token in enumerate(temp)]
            else:
                data_dict[feature] = df[feature].tolist()
        data_list.append(data_dict)
    return data_list

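# Illustrative shape of one load_data() record (the values here are hypothetical,
# not taken from the dataset):
# {"file_ID": "12345", "text": "ein Beispielsatz",
#  "maingloss": ["HELLO", "WORLD"], "domgloss": ["HELLO", "WORLD"],
#  "framestart": [0, 42], "domhandrelocx": [0.1, 0.3], ...}
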
def normalize_inflection_features(data_x, features_names):

    # Collect the global min/max per continuous feature; the midpoint of the
    # range serves as a fallback for missing values.
    inflection_stats = {}
    for feature in features_names[7:]:
        all_values = []
        for data in data_x:
            if feature in data and data[feature]:
                all_values.extend([v for v in data[feature] if not np.isnan(v)])
        if all_values:
            global_min = min(all_values)
            global_max = max(all_values)
            mean_value = (global_min + global_max) / 2
            inflection_stats[feature] = {
                'global_min': global_min,
                'global_max': global_max,
                'mean_value': mean_value
            }
            print(f"{feature} - Global min: {global_min:.4f}, Global max: {global_max:.4f}, Mean: {mean_value:.4f}")
        else:
            print(f"Warning: No valid values found for {feature}")
            inflection_stats[feature] = None

    # Preprocess the data: min-max normalize and fill in missing values
    for data in data_x:
        for feature in features_names[7:]:
            if inflection_stats[feature] is None:
                continue

            if feature not in data or not data[feature]:
                # Feature entirely missing: repeat the fallback value once per text token
                data[feature] = [inflection_stats[feature]['mean_value']] * len(data['text'].split())
            else:
                value_range = inflection_stats[feature]['global_max'] - inflection_stats[feature]['global_min']
                normalized_values = []
                for j, value in enumerate(data[feature]):
                    # A zero range would divide by zero, so fall back to the midpoint
                    if np.isnan(value) or value_range == 0:
                        normalized_values.append(inflection_stats[feature]['mean_value'])
                    else:
                        normalized_value = (value - inflection_stats[feature]['global_min']) / value_range
                        if np.isnan(normalized_value):
                            print(f"Normalization resulted in NaN in file: {data.get('file_ID', 'Unknown')} for feature {feature} at index {j}. Value: {value}")
                            normalized_value = inflection_stats[feature]['mean_value']
                        normalized_values.append(normalized_value)
                data[feature] = normalized_values

    return data_x, inflection_stats

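# Worked example of the normalization above (hypothetical numbers): with
# global_min = 0.2 and global_max = 0.8, a raw value of 0.5 maps to
# (0.5 - 0.2) / (0.8 - 0.2) = 0.5, and a NaN entry is replaced by the
# range midpoint (0.2 + 0.8) / 2 = 0.5.
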
# Tokenize text data with the pretrained GPT-2 (English) tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
data_list = load_data(directory, text_directory, features_names)
# Split data into training and validation sets
train_data, val_data = train_test_split(data_list, test_size=0.25, random_state=42)
# Normalization statistics are computed on the training split only
# (normalize_inflection_features also mutates train_data in place)
normalized_train_data, inflection_stats = normalize_inflection_features(train_data, features_names)

# Define custom dataset class
class SignLanguageDataset(Dataset):
    def __init__(self, data_list, tokenizer, inflection_stats, max_length=512):
        self.data_list = data_list
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.vocab_size = len(tokenizer)
        # framestart is standardized with statistics from this dataset's own samples
        self.framestart_scaler = StandardScaler()
        all_framestart = [item for data in data_list for item in data['framestart']]
        self.framestart_scaler.fit(np.array(all_framestart).reshape(-1, 1))
        self.inflection_stats = inflection_stats

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        data = self.data_list[idx]
        file_Id = data['file_ID']
        text_tokens = self.tokenizer.encode(data['text'], add_special_tokens=True, max_length=self.max_length, truncation=True)
        text_tokens = torch.tensor(text_tokens)

        # Gloss sequences are joined into a single string and GPT-2 tokenized
        gloss_feature = ["maingloss", "domgloss", "ndomgloss"]
        gloss_tokens = {}
        for feature in gloss_feature:
            if feature in data:
                tokens = self.tokenizer.encode(' '.join(data[feature]), add_special_tokens=True, max_length=self.max_length, truncation=True)
                gloss_tokens[feature] = torch.tensor(tokens)

        # Continuous inflection features (already min-max normalized for the training split)
        feature_tensors = {}
        for feature in self.inflection_stats.keys():
            if feature in data:
                feature_tensors[feature] = torch.tensor(data[feature], dtype=torch.float32)

        framestart = self.framestart_scaler.transform(np.array(data['framestart']).reshape(-1, 1)).flatten()
        framestart = torch.tensor(framestart, dtype=torch.float32)

        boolean_features = {}
        for feature in ['domreloc', 'ndomreloc']:
            if feature in data:
                boolean_features[feature] = torch.tensor([transform_bool_feature(val) for val in data[feature]], dtype=torch.long)

        return file_Id, text_tokens, gloss_tokens, framestart, boolean_features, feature_tensors

    def get_scaler(self):
        return self.framestart_scaler

    def get_inflection_stats(self):
        return self.inflection_stats

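# Each dataset item is a 6-tuple:
# (file_ID, text token ids, {gloss name: token ids}, standardized framestart,
#  {boolean feature: int codes}, {inflection feature: float values}).
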
def transform_bool_feature(val):
    # Map annotation strings ('yes'/'no'/'0'/'1'/'2') to integer codes;
    # anything unrecognized falls back to 0
    val = val.lower().strip()
    return HEADMOV_TO_INT.get(val, 0)

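# e.g. transform_bool_feature('Yes') -> 4, transform_bool_feature('no') -> 3,
# and an unseen value such as 'maybe' -> 0.
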
def collate_fn(batch):
    file_Id, text_tokens, gloss_tokens, framestart, boolean_features, feature_tensors = zip(*batch)

    # GPT-2 has no pad token by default, so fall back to id 0 for padding
    padding_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    text_tokens_padded = pad_sequence(text_tokens, batch_first=True, padding_value=padding_value)

    gloss_tokens_padded = {}
    for feature in ['maingloss', 'domgloss', 'ndomgloss']:
        if all(feature in sample for sample in gloss_tokens):
            gloss_tokens_padded[feature] = pad_sequence([sample[feature] for sample in gloss_tokens],
                                                        batch_first=True,
                                                        padding_value=padding_value)
    # Ensure all sequences in the batch share the same length
    # (the list form keeps max() safe even if no gloss feature was collated)
    max_len = max([text_tokens_padded.size(1)] + [tensor.size(1) for tensor in gloss_tokens_padded.values()])

    text_tokens_padded = F.pad(text_tokens_padded, (0, max_len - text_tokens_padded.size(1)), value=padding_value)

    for feature in gloss_tokens_padded:
        gloss_tokens_padded[feature] = F.pad(gloss_tokens_padded[feature], (0, max_len - gloss_tokens_padded[feature].size(1)), value=padding_value)

    # Boolean features use -1 as an "ignore" padding value
    boolean_features_padded = {}
    for feature in boolean_features[0].keys():
        feature_list = [sample[feature] for sample in boolean_features]
        padded_feature = pad_sequence(feature_list, batch_first=True, padding_value=-1)
        padded_feature = F.pad(padded_feature, (0, max_len - padded_feature.size(1)), value=-1)
        boolean_features_padded[feature] = padded_feature

    framestart_padded = pad_sequence(framestart, batch_first=True, padding_value=0.0)
    framestart_padded = F.pad(framestart_padded, (0, max_len - framestart_padded.size(1)), value=0.0)

    inflection_features_padded = {}
    for feature in feature_tensors[0].keys():
        feature_list = [sample[feature] for sample in feature_tensors]
        padded_feature = pad_sequence(feature_list, batch_first=True, padding_value=0.0)
        padded_feature = F.pad(padded_feature, (0, max_len - padded_feature.size(1)), value=0.0)
        inflection_features_padded[feature] = padded_feature

    return file_Id, text_tokens_padded, gloss_tokens_padded, framestart_padded, boolean_features_padded, inflection_features_padded

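# Example batch shapes after collation (hypothetical, batch_size=16, max_len=40):
# text_tokens_padded: (16, 40), gloss_tokens_padded['maingloss']: (16, 40),
# framestart_padded: (16, 40), boolean_features_padded['domreloc']: (16, 40).
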
# Create DataLoader instances
train_dataset = SignLanguageDataset(train_data, tokenizer, inflection_stats)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

# Validation uses batch_size=1; note that val_data is passed through without
# the min-max normalization applied to the training split
val_dataset = SignLanguageDataset(val_data, tokenizer, inflection_stats)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
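
# Minimal smoke-test sketch: pull one batch and inspect shapes; the tuple
# order matches collate_fn's return value (everything referenced here is
# defined above in this file).
if __name__ == "__main__":
    file_ids, text, gloss, framestart, bools, inflections = next(iter(train_dataloader))
    print(text.shape, framestart.shape, list(gloss.keys()))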
