
Commit f528f4a

split data into train and test and save them in json files
1 parent 52a61fe commit f528f4a

File tree

1 file changed: +216 -0 lines changed


data_selection.py

Lines changed: 216 additions & 0 deletions
@@ -0,0 +1,216 @@
#!/usr/bin/env python
# coding: utf-8

# In[29]:


import math
import torchtext
import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import vocab
from torch import Tensor
import io
import time
import os
import pandas as pd

torch.manual_seed(0)
# Force PyTorch operations to use deterministic algorithms; if a deterministic
# implementation is not available, PyTorch raises a RuntimeError.
# torch.use_deterministic_algorithms(True)


from datetime import datetime

# Timestamped folder name for saving inference outputs.
save_folder_address = "inference" + str(datetime.now()).replace(" ", "__")
# In[30]:


print("haha")  # debug marker


# In[31]:
features_names = ["maingloss", "domgloss", "ndomgloss", "domreloc", "ndomreloc",
                  "domhandrelocx", "domhandrelocy", "domhandrelocz", "domhandrelocax",
                  "domhandrelocay", "domhandrelocaz", "domhandrelocsx", "domhandrelocsy", "domhandrelocsz",
                  "domhandrotx", "domhandroty", "domhandrotz",
                  "ndomhandrelocx", "ndomhandrelocy", "ndomhandrelocz", "ndomhandrelocax",
                  "ndomhandrelocay", "ndomhandrelocaz", "ndomhandrelocsx", "ndomhandrelocsy", "ndomhandrelocsz",
                  "ndomhandrotx", "ndomhandroty", "ndomhandrotz"]

directory = "mms-subset91"
text_directory = "annotations-full/annotations"
data_list = []
for filename in os.listdir(directory):
    f = os.path.join(directory, filename)
    df = pd.read_csv(f)

    filenumber = filename.split(".")[0]

    # Read the German sentence for this sample; the text is the third
    # ";"-separated field of each annotation line.
    text_address = os.path.join(text_directory, filenumber, "gebaerdler.Text_Deutsch.annotation~")
    with open(text_address, encoding='latin-1') as file:
        lines = file.readlines()
    text_line = ""
    for i, text_data in enumerate(lines):
        if i > 0:
            text_line = text_line + " " + text_data.replace("\n", "").split(";")[2]
        else:
            text_line = text_line + text_data.replace("\n", "").split(";")[2]

    data_dict = {"file_ID": filenumber, "text": text_line}
    for feature in features_names:
        if feature == "domgloss" or feature == "ndomgloss":
            # Fill missing dominant/non-dominant glosses with the main gloss.
            # "maingloss" precedes both in features_names, so data_dict["maingloss"] already exists here.
            temp = df[feature].copy()
            data_dict[feature] = [data_dict["maingloss"][i] if pd.isnull(token) else token for i, token in enumerate(temp)]
        else:
            data_dict[feature] = df[feature].tolist()
    data_list.append(data_dict)
# data_list is a list of dictionaries.
# Each dictionary corresponds to one data sample in the dataset.
# "file_ID" is the file number, "text" is the German sentence, and the remaining
# keys are lists of equal length holding gloss tokens, booleans, and real-valued numbers.
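# A minimal sketch of one entry (hypothetical values, truncated to a few keys):
#
# {
#     "file_ID": "0042",
#     "text": "ein deutscher Beispielsatz",
#     "maingloss": ["GLOSS-A", "GLOSS-B"],
#     "domreloc": ["yes", "no"],           # converted to 1/0 in the next cell
#     "domhandrelocx": [0.12, -0.07],
#     ...
# }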
# In[32]:


boolean_map = {"yes": 1, "no": 0}
for data in data_list:
    data["domreloc"] = [boolean_map[value] for value in data["domreloc"]]
    data["ndomreloc"] = [boolean_map[value] for value in data["ndomreloc"]]
    # data["shoulders"] = [boolean_map[value] for value in data["shoulders"]]
# In[33]:


def build_German_vocab(data_list, tokenizer):
    """
    Build a vocabulary over the German sentences in the dataset.

    :param data_list: list of data-sample dictionaries with a "text" field
    :param tokenizer: tokenizer for the text language (German)
    :return: torchtext Vocab over the tokenized sentences
    """
    counter = Counter()
    for data in data_list:
        tokenized_text = tokenizer(data["text"])
        counter.update(tokenized_text)
    return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
# In[34]:


def build_gloss_vocab(data_list, gloss_name):
    """
    Build a vocabulary over one gloss feature of the dataset.

    :param data_list: list of data-sample dictionaries
    :param gloss_name: key of the gloss feature to index, e.g. "maingloss"
    :return: torchtext Vocab over the gloss tokens
    """
    counter = Counter()
    for data in data_list:
        counter.update(data[gloss_name])
    return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
# In[35]:


# get spacy tokenizer for German text
de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')

de_vocab = build_German_vocab(data_list, de_tokenizer)
gl_vocab = build_gloss_vocab(data_list, "maingloss")
dom_vocab = build_gloss_vocab(data_list, "domgloss")
ndom_vocab = build_gloss_vocab(data_list, "ndomgloss")
# I added these lines because some tokens from the validation and test sets
# do not appear in the training set.

# In[36]:


UNK_IDX = de_vocab['<unk>']
de_vocab.set_default_index(UNK_IDX)
gl_vocab.set_default_index(UNK_IDX)
dom_vocab.set_default_index(UNK_IDX)
ndom_vocab.set_default_index(UNK_IDX)
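# Illustrative sanity check (the token string is made up): with a default index
# set, an out-of-vocabulary lookup returns UNK_IDX instead of raising an error.
assert de_vocab["token-never-seen-in-training"] == UNK_IDX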
# In[37]:


# divide train and test here
import random

# 75% / 25% train-test split
div = math.floor(len(data_list) * 0.75)

data_list_copy = data_list.copy()

random.seed(1)
random.shuffle(data_list_copy)

test_data_raw = data_list_copy[div:]
train_data_raw = data_list_copy[0:div]
# In[38]:


import json

with open("test_data.json", "w") as outfile:
    json.dump(test_data_raw, outfile)

with open("train_data.json", "w") as outfile:
    json.dump(train_data_raw, outfile)
# In[39]:


with open('test_data.json', 'r') as openfile:
    test_data_raw = json.load(openfile)

print(len(test_data_raw))


with open('train_data.json', 'r') as openfile:
    train_data_raw = json.load(openfile)

print(len(train_data_raw))
# In[8]:


# with open('../Downloads/test_data.json', 'r') as openfile:
#     json_object_server = json.load(openfile)

# print(len(json_object_server))


# In[10]:


# for item in json_object_server:
#     print(item['file_ID'])
#     print(item["text"])
#     print(item["maingloss"])


# In[ ]:
