Commit 19ac968

Merge pull request #1 from MaithriRao/data_augmentation
Add code for data augmentation
2 parents 8b6fbad + de81a44

File tree: 7 files changed, +1051 −0 lines changed

augmented_data/const.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
import os


main_folder_text = "annotations_full/annotations"
main_folder_mms = "mms-subset91"

sub_folders_text = sorted([f.path for f in os.scandir(main_folder_text) if f.is_dir()])

text_file_name = "gebaerdler.Text_Deutsch.annotation~"
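
For context, these constants imply a working-tree layout roughly like the one below. Only the two top-level paths and the annotation file name come from the code above; the four-digit folder names are borrowed from the ones dataset.py special-cases (0090, 0099, 0101, 0102) and stand in for the full set:

annotations_full/annotations/
    0090/gebaerdler.Text_Deutsch.annotation~
    0101/gebaerdler.Text_Deutsch.annotation~
    ...
mms-subset91/
    0090.mms
    0101.mms
    ...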

augmented_data/dataset.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
import os
import csv

from .const import *


def read_dataset_text():
    dataset = {}

    for folder in sub_folders_text:
        file_path = os.path.join(folder, text_file_name)

        if not os.path.exists(file_path):
            continue

        folder_name = os.path.basename(os.path.dirname(file_path))

        # Some files are encoded with ISO 8859-1, some with UTF-8.
        # Trying UTF-8 first and falling back to ISO 8859-1 does not
        # work, because some files decode as valid UTF-8 even though
        # they were written as ISO 8859-1. To work around that, just
        # hardcode which files contain UTF-8.
        # 0099 is completely broken, so skip it.
        if folder_name == "0099":
            continue
        if folder_name in ("0090", "0101", "0102"):
            encoding = 'utf-8'
        else:
            encoding = 'iso-8859-1'

        with open(file_path, 'r', encoding=encoding) as f:
            lines = f.readlines()

        parsed_file = []
        for line in lines:
            start_time, end_time, sentence, number = line.strip().split(";")
            assert number == "1", f"number is {number}"

            parsed_file.append((start_time, end_time, sentence, number))
        dataset[folder_name] = parsed_file

    return dataset


def read_dataset_mms():
    dataset = {}

    for file_name in sorted(os.listdir(main_folder_mms)):
        file_path = os.path.join(main_folder_mms, file_name)
        file_number = file_name.rsplit('.', maxsplit=1)[0]

        parsed_file = []
        with open(file_path, 'r', encoding='utf-8', newline='') as f:
            reader = csv.DictReader(f)
            for row in reader:
                parsed_file.append(row)

        dataset[file_number] = parsed_file

    return dataset


def write_dataset_text(original_dataset, dataset, main_folder):
    for folder_name, file_contents in dataset.items():
        # Only write out files that were actually modified.
        original_file_contents = original_dataset[folder_name]
        if original_file_contents == file_contents:
            continue
        file_folder = os.path.join(main_folder, folder_name)
        os.makedirs(file_folder, exist_ok=True)
        file_path = os.path.join(file_folder, text_file_name)
        with open(file_path, 'w', encoding='utf-8') as f:
            for row in file_contents:
                text_line = ";".join(row)
                f.write(text_line + '\n')


def write_dataset_mms(original_dataset, dataset, main_folder):
    fieldnames = ['maingloss', 'framestart', 'frameend', 'duration', 'transition',
                  'domgloss', 'ndomgloss', 'domreloc', 'ndomreloc', 'headpos',
                  'headmov', 'cheecks', 'nose', 'mouthgest', 'mouthing',
                  'eyegaze', 'eyeaperture', 'eyebrows', 'neck', 'shoulders',
                  'torso',
                  'domhandrelocx', 'domhandrelocy', 'domhandrelocz',
                  'domhandrelocax', 'domhandrelocay', 'domhandrelocaz',
                  'domhandrelocsx', 'domhandrelocsy', 'domhandrelocsz',
                  'domhandrotx', 'domhandroty', 'domhandrotz',
                  'ndomhandrelocx', 'ndomhandrelocy', 'ndomhandrelocz',
                  'ndomhandrelocax', 'ndomhandrelocay', 'ndomhandrelocaz',
                  'ndomhandrelocsx', 'ndomhandrelocsy', 'ndomhandrelocsz',
                  'ndomhandrotx', 'ndomhandroty', 'ndomhandrotz']
    os.makedirs(main_folder, exist_ok=True)

    for file_name, file_contents in dataset.items():
        # Only write out files that were actually modified.
        original_file_contents = original_dataset[file_name]
        if original_file_contents == file_contents:
            continue
        file_path = os.path.join(main_folder, file_name + '.mms')
        with open(file_path, 'w', encoding='utf-8', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            writer.writeheader()
            for row in file_contents:
                writer.writerow(row)
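
For reference, read_dataset_text above expects one semicolon-separated record per line, and write_dataset_text re-serializes records the same way. A minimal round-trip sketch with a fabricated line (timestamps and sentence are invented; only the start;end;sentence;number shape and the number == "1" invariant come from the parser):

# Fabricated example line in the shape the parser above expects.
line = "00:00:01;00:00:04;Der Zug nach Köln fährt heute von Gleis 3.;1"

start_time, end_time, sentence, number = line.strip().split(";")
assert number == "1", f"number is {number}"

# write_dataset_text re-serializes the tuple the same way:
print(";".join((start_time, end_time, sentence, number)))
# 00:00:01;00:00:04;Der Zug nach Köln fährt heute von Gleis 3.;1

Since the parser unpacks exactly four fields, a sentence containing a semicolon of its own would raise a ValueError rather than misparse silently.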

augmented_data/location.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
# Run with: python -m augmented_data.location

import random

import spacy

from .utils import replace_multiple, normalize_text_to_mms


nlp = spacy.load('de_core_news_sm')


def replace_location_entities(dataset_text, dataset_mms):
    location_names = set()
    dataset_text_with_metadata = {}

    # Entities that spaCy mislabels as LOC but must not be replaced.
    excludes = {'Zuges', 'Alternativen', 'D', 'RE 77', 'A.', 'D.', 'Umsteigen', 'Weiteres',
                'Notbremse', 'Reservierungen', 'IC 2313', 'Sonderzug', 'Rhein'}

    # First pass: collect the set of all location names across the corpus
    # and attach the recognized entities to every line.
    for folder_name, file_content in dataset_text.items():
        file_with_metadata = []
        for (start_time, end_time, sentence, number) in file_content:
            # Strip German quotation marks before running NER.
            sentence_to_analyze = sentence.strip().translate({ord(i): None for i in "„“"})
            doc = nlp(sentence_to_analyze)
            entities = [(ent.text, ent.label_) for ent in doc.ents]

            for ent in doc.ents:
                if ent.label_ == 'LOC' and ent.text not in excludes:
                    location = ent.text
                    if 'Hauptbahnhof' in location:
                        # Remove "Hauptbahnhof" and strip the extra space.
                        location = location.replace('Hauptbahnhof', '').strip()
                    location_names.add(location)
            file_with_metadata.append((start_time, end_time, sentence, number, entities))

        dataset_text_with_metadata[folder_name] = file_with_metadata

    result_text = {}
    result_mms = {}

    # Second pass: pick a random replacement for each location in each file
    # and apply the mapping to both the text and the mms glosses.
    for folder_name, file_content in dataset_text_with_metadata.items():
        location_counts = {}
        for (start_time, end_time, sentence, number, entities) in file_content:
            for (text, label) in entities:
                if label == 'LOC' and text not in excludes:
                    if 'Hauptbahnhof' in text:
                        text = text.replace('Hauptbahnhof', '').strip()
                    # Count how many times the same location appears in a file.
                    location_counts[text] = location_counts.get(text, 0) + 1

        mapping = {}
        for location, count in location_counts.items():
            assert len(location_names) > 1, 'ERROR: only one location found'
            while True:
                new_location = random.choice(tuple(location_names))
                if new_location != location:
                    break
            mapping[location] = new_location

        new_text_data = []
        for (start_time, end_time, sentence, number, entities) in file_content:
            sentence, _ = replace_multiple(sentence, mapping)
            new_text_data.append((start_time, end_time, sentence, number))
        result_text[folder_name] = new_text_data

        mapping_mms = {normalize_text_to_mms(k): normalize_text_to_mms(v)
                       for k, v in mapping.items()}
        replaced_counts = {}
        new_mms_data = []
        for row in dataset_mms[folder_name]:
            new_row = row.copy()
            word = row['maingloss']
            if word in mapping_mms:
                new_row['maingloss'] = mapping_mms[word]
                replaced_counts[word] = replaced_counts.get(word, 0) + 1
            new_mms_data.append(new_row)
        result_mms[folder_name] = new_mms_data

        # Sanity check: every location found in the text should have been
        # replaced the same number of times in the mms glosses.
        for location, count in location_counts.items():
            location_mms = normalize_text_to_mms(location)
            replaced_count = replaced_counts.get(location_mms, 0)
            if replaced_count != count:
                print(f'WARNING: replaced_count in file {folder_name} should be {count} but was {replaced_count}, trying to replace {location_mms}')

    return (result_text, result_mms)


if __name__ == "__main__":
    from .dataset import *

    dataset_text = read_dataset_text()
    dataset_mms = read_dataset_mms()

    result_text, result_mms = replace_location_entities(dataset_text, dataset_mms)
    write_dataset_text(dataset_text, result_text, main_folder='modified/location/text')
    write_dataset_mms(dataset_mms, result_mms, main_folder='modified/location/mms')
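
Both replace_multiple and normalize_text_to_mms come from augmented_data/utils.py, which is part of this commit but not shown in this excerpt. Purely as a reading aid, here is a hypothetical sketch of the two helpers inferred from their call sites above; the actual implementations in utils.py may differ:

# Hypothetical reconstructions, inferred from usage; not the actual
# contents of augmented_data/utils.py.

def replace_multiple(sentence, mapping):
    # Replace every occurrence of each key with its value; return the new
    # sentence and the number of replacements made.
    count = 0
    for old, new in mapping.items():
        occurrences = sentence.count(old)
        if occurrences:
            sentence = sentence.replace(old, new)
            count += occurrences
    return sentence, count


def normalize_text_to_mms(text):
    # Guess: maingloss entries look like upper-case ASCII tokens, so map
    # umlauts to digraphs before upper-casing ('ß'.upper() is already 'SS').
    for old, new in (('ä', 'ae'), ('ö', 'oe'), ('ü', 'ue')):
        text = text.replace(old, new).replace(old.upper(), new.upper())
    return text.upper()

One caveat with a sequential replace like this: if one mapping's value equals another mapping's key (A → B, B → C), replacements can chain; a single regex pass over all keys would avoid that.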

augmented_data/platform.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
# Run with: python -m augmented_data.platform

import random
import re
from collections import Counter
from enum import Enum, UNIQUE, verify

from .utils import replace_multiple, parse_num


# States for the GLEIS ... num: scanner over the mms glosses.
@verify(UNIQUE)
class State(Enum):
    NOT_FOUND = 0
    GLEIS = 1
    WECHSELN = 2
    NUM = 3


def replace_platform_entities(dataset_text, dataset_mms):
    all_platforms = set()
    dataset_text_with_metadata = {}

    # Matches e.g. "Gleis 4" and "Gleis 4a", capturing only "Gleis 4".
    platform_pattern = r'\b(Gleis \d+)[a-z]?\b'

    # First pass: collect every platform mentioned anywhere in the corpus
    # and count, per file, how often each platform appears.
    for folder_name, file in dataset_text.items():
        platforms_per_file = []
        for (start_time, end_time, sentence, number) in file:
            platforms_per_file += re.findall(platform_pattern, sentence)

        all_platforms = all_platforms.union(set(platforms_per_file))
        platform_counts = Counter(platforms_per_file)
        dataset_text_with_metadata[folder_name] = (file, platform_counts)

    result_text = {}
    result_mms = {}
    for folder_name, (file, platform_counts) in dataset_text_with_metadata.items():
        mapping = {}
        all_platforms_tuple = tuple(all_platforms)
        for platform, count in platform_counts.items():  # TODO: think about randomly shuffling instead
            assert len(all_platforms) > 1, 'ERROR: only one platform found'
            while True:
                new_platform = random.choice(all_platforms_tuple)
                if new_platform != platform:
                    break
            mapping[platform] = new_platform

        new_text_data = []
        for start_time, end_time, sentence, number in file:
            sentence, _ = replace_multiple(sentence, mapping)
            new_text_data.append((start_time, end_time, sentence, number))
        result_text[folder_name] = new_text_data

        # Scan the glosses for GLEIS, optionally followed by WECHSELN,
        # followed by a num: gloss, and replace that number according
        # to the mapping.
        replaced_counts = {}
        new_mms_data = []
        state = State.NOT_FOUND
        for row in dataset_mms[folder_name]:
            new_row = row.copy()
            word = row['maingloss']
            if state == State.NOT_FOUND:
                if word == 'GLEIS':
                    state = State.GLEIS
            elif state in (State.GLEIS, State.WECHSELN):
                if state == State.GLEIS and word == 'WECHSELN':
                    state = State.WECHSELN
                elif word.startswith('num:'):
                    state = State.NUM
                else:
                    print(f'WARNING: expected WECHSELN or num:, got {word} in file {folder_name}')
                    state = State.NOT_FOUND

            if state == State.NUM:
                num = parse_num(word, folder_name)
                old_gleis = f'Gleis {num}'
                new_gleis = mapping[old_gleis]
                print(f'Found {old_gleis} in file {folder_name}, replacing with {new_gleis}')
                new_num = new_gleis.removeprefix('Gleis ')
                new_row['maingloss'] = 'num:' + new_num
                replaced_counts[old_gleis] = replaced_counts.get(old_gleis, 0) + 1
                state = State.NOT_FOUND

            new_mms_data.append(new_row)
        result_mms[folder_name] = new_mms_data

        # Sanity check: every platform found in the text should have been
        # replaced the same number of times in the mms glosses.
        for platform, count in platform_counts.items():
            replaced_count = replaced_counts.get(platform, 0)
            if replaced_count != count:
                print(f'WARNING: replaced_count in file {folder_name} should be {count} but was {replaced_count}, trying to replace {platform}')

    return (result_text, result_mms)


if __name__ == "__main__":
    from .dataset import *

    dataset_text = read_dataset_text()
    dataset_mms = read_dataset_mms()

    result_text, result_mms = replace_platform_entities(dataset_text, dataset_mms)
    write_dataset_text(dataset_text, result_text, main_folder='modified/platform/text')
    write_dataset_mms(dataset_mms, result_mms, main_folder='modified/platform/mms')
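
To make the gloss scanner concrete, here is a worked trace on a fabricated sequence. parse_num is also defined in the unshown utils.py and is assumed here to extract the number from a num: gloss:

# Fabricated gloss sequence; assume parse_num('num:3', ...) returns 3.
glosses = ['ZUG', 'GLEIS', 'WECHSELN', 'num:3', 'ABFAHREN']
mapping = {'Gleis 3': 'Gleis 7'}

# State walk: NOT_FOUND -> GLEIS (on 'GLEIS') -> WECHSELN (on 'WECHSELN')
# -> NUM (on 'num:3'), where 'num:3' is rewritten to 'num:7' via the
# mapping, then back to NOT_FOUND. 'ZUG' and 'ABFAHREN' pass through
# unchanged, so the result is:
#   ['ZUG', 'GLEIS', 'WECHSELN', 'num:7', 'ABFAHREN']
# Without WECHSELN, the sequence 'GLEIS', 'num:3' takes the direct
# GLEIS -> NUM edge and is rewritten the same way.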
