#!/usr/bin/env python
# coding: utf-8

# In[29]:


import math
import io
import time
import os
from collections import Counter

import pandas as pd
import torch
import torch.nn as nn
import torchtext
from torch import Tensor
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab

torch.manual_seed(0)
# Force PyTorch operations to use "deterministic" algorithms; with this
# enabled, a RuntimeError is raised when a deterministic implementation is
# not available.
# torch.use_deterministic_algorithms(True)


from datetime import datetime

# Timestamped name for the folder that will hold inference outputs,
# e.g. "inference2024-01-01__12:00:00.000000" (spaces replaced with "__").
save_folder_address = "inference" + str(datetime.now()).replace(" ", "__")
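
# The folder name above is only constructed here; below is a minimal sketch of
# creating the directory up front (an assumption -- the original may create it
# later, at save time):

# In[ ]:


# exist_ok makes this idempotent; note the ':' characters the timestamp leaves
# in the name are fine on POSIX systems but invalid in Windows paths.
os.makedirs(save_folder_address, exist_ok=True)
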
# In[31]:


features_names = ["maingloss", "domgloss", "ndomgloss", "domreloc", "ndomreloc",
                  "domhandrelocx", "domhandrelocy", "domhandrelocz", "domhandrelocax",
                  "domhandrelocay", "domhandrelocaz", "domhandrelocsx", "domhandrelocsy", "domhandrelocsz",
                  "domhandrotx", "domhandroty", "domhandrotz",
                  "ndomhandrelocx", "ndomhandrelocy", "ndomhandrelocz", "ndomhandrelocax",
                  "ndomhandrelocay", "ndomhandrelocaz", "ndomhandrelocsx", "ndomhandrelocsy", "ndomhandrelocsz",
                  "ndomhandrotx", "ndomhandroty", "ndomhandrotz"]

directory = "mms-subset91"
text_directory = "annotations-full/annotations"
data_list = []
for filename in os.listdir(directory):
    f = os.path.join(directory, filename)
    df = pd.read_csv(f)

    filenumber = filename.split(".")[0]

    # The German sentence is stored one token per line; the token itself is
    # the third ';'-separated field.
    text_address = os.path.join(text_directory, filenumber, "gebaerdler.Text_Deutsch.annotation~")
    with open(text_address, encoding='latin-1') as file:
        lines = file.readlines()
    text_line = " ".join(line.replace("\n", "").split(";")[2] for line in lines)

    data_dict = {"file_ID": filenumber, "text": text_line}
    for feature in features_names:
        if feature == "domgloss" or feature == "ndomgloss":
            # Fall back to the main gloss wherever a hand-specific gloss is missing.
            temp = df[feature].copy()
            data_dict[feature] = [data_dict["maingloss"][i] if pd.isnull(token) else token
                                  for i, token in enumerate(temp)]
        else:
            data_dict[feature] = df[feature].tolist()
    data_list.append(data_dict)


# data_list is a list of dictionaries.
# Each dictionary corresponds to one data sample in the dataset:
# file_ID is the file number, text is the German sentence, and the remaining
# keys are lists of equal length holding gloss tokens, booleans, and
# real-valued numbers.
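
# To make the structure just described concrete, a small illustrative
# inspection of one sample (output naturally depends on the dataset):

# In[ ]:


sample = data_list[0]
print(sample["file_ID"])        # the file number, e.g. "123"
print(sample["text"])           # the German sentence
print(sample["maingloss"][:5])  # first few main-gloss tokens
# all per-token feature lists within a sample share one length
assert len({len(sample[name]) for name in features_names}) == 1
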
# In[32]:


# Map the yes/no relocation flags to integers.
boolean_map = {"yes": 1, "no": 0}
for data in data_list:
    data["domreloc"] = [boolean_map[value] for value in data["domreloc"]]
    data["ndomreloc"] = [boolean_map[value] for value in data["ndomreloc"]]
#     data["shoulders"] = [boolean_map[value] for value in data["shoulders"]]

# In[33]:


def build_German_vocab(data_list, tokenizer):
    """
    Build a vocabulary over the German sentences.

    :param data_list: list of data sample dicts, each with a "text" field
    :param tokenizer: tokenizer for the German text
    :return: torchtext vocab for German
    """
    counter = Counter()
    for data in data_list:
        tokenized_text = tokenizer(data["text"])
        counter.update(tokenized_text)
    return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])


# In[34]:


def build_gloss_vocab(data_list, gloss_name):
    """
    Build a vocabulary over one gloss track.

    :param data_list: list of data sample dicts
    :param gloss_name: which gloss field to use ("maingloss", "domgloss", or "ndomgloss")
    :return: torchtext vocab for that gloss track
    """
    counter = Counter()
    for data in data_list:
        counter.update(data[gloss_name])
    return vocab(counter, specials=['<unk>', '<pad>', '<bos>', '<eos>'])


# In[35]:


# get the spaCy tokenizer for German text
de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')

de_vocab = build_German_vocab(data_list, de_tokenizer)
gl_vocab = build_gloss_vocab(data_list, "maingloss")
dom_vocab = build_gloss_vocab(data_list, "domgloss")
ndom_vocab = build_gloss_vocab(data_list, "ndomgloss")

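
# A small usage sketch (the sentence is illustrative): tokenize German text
# and look the tokens up in the vocabulary.  At this point an out-of-vocabulary
# token would raise a RuntimeError; the next cell sets a default index for
# exactly that reason.

# In[ ]:


example_tokens = de_tokenizer("Guten Morgen")
example_indices = [de_vocab[token] for token in example_tokens if token in de_vocab]
print(example_tokens, example_indices)
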

# I added these lines because some tokens in the validation and test sets do
# not occur in the training set.

# In[36]:


UNK_IDX = de_vocab['<unk>']
# '<unk>' is index 0 in every vocabulary built above (specials come first), so
# the same UNK_IDX works for all of them; unknown tokens now map to <unk>
# instead of raising an error.
de_vocab.set_default_index(UNK_IDX)
gl_vocab.set_default_index(UNK_IDX)
dom_vocab.set_default_index(UNK_IDX)
ndom_vocab.set_default_index(UNK_IDX)

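
# With the default index set, an out-of-vocabulary lookup now falls back to
# <unk> -- a one-line check (the token is deliberately made up):

# In[ ]:


assert de_vocab["xyzzy-not-a-token"] == UNK_IDX
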
# In[37]:


# divide the data into train and test sets here
import random

# 75/25 train/test split (math is already imported above)
div = math.floor(len(data_list) * 0.75)

data_list_copy = data_list.copy()

random.seed(1)
random.shuffle(data_list_copy)

test_data_raw = data_list_copy[div:]
train_data_raw = data_list_copy[:div]

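
# A quick sanity check that the split is disjoint and covers every sample
# (file_ID is unique per sample, since it comes from unique filenames):

# In[ ]:


train_ids = {d["file_ID"] for d in train_data_raw}
test_ids = {d["file_ID"] for d in test_data_raw}
assert train_ids.isdisjoint(test_ids)
assert len(train_ids) + len(test_ids) == len(data_list)
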
# In[38]:


import json

# Persist the raw splits so that later runs reload exactly the same
# train/test partition.
with open("test_data.json", "w") as outfile:
    json.dump(test_data_raw, outfile)

with open("train_data.json", "w") as outfile:
    json.dump(train_data_raw, outfile)

# In[39]:


with open('test_data.json', 'r') as openfile:
    test_data_raw = json.load(openfile)

print(len(test_data_raw))


with open('train_data.json', 'r') as openfile:
    train_data_raw = json.load(openfile)

print(len(train_data_raw))

# In[8]:


# with open('../Downloads/test_data.json', 'r') as openfile:
#     json_object_server = json.load(openfile)

# print(len(json_object_server))


# In[10]:


# for item in json_object_server:
#     print(item['file_ID'])
#     print(item["text"])
#     print(item["maingloss"])