Skip to content

Commit 4b39cda

Browse files
committed
Add train name in data augmentation
1 parent 5f30f31 commit 4b39cda

File tree

1 file changed

+360
-0
lines changed

1 file changed

+360
-0
lines changed

augmented_data/train_name.py

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
# Run with: python -m augmented_data.train_name
2+
3+
import re
4+
import random
5+
from enum import Enum, verify, UNIQUE
6+
from collections import defaultdict
7+
from .utils import has_numbers, replace_multiple
8+
9+
# Exclude 0
10+
single_sign_numbers = set()
11+
for i in range(1, 20):
12+
single_sign_numbers.add(i)
13+
14+
for i in range(2, 10):
15+
single_sign_numbers.add(i * 10) # Tens
16+
single_sign_numbers.add(i * 11) # Repdigits
17+
18+
print(f'single_sign_numbers: {single_sign_numbers}')
19+
20+
class TrainData:
21+
def __init__(self):
22+
self.type_line = None
23+
self.train_type = None
24+
self.num_has_dashes = False
25+
self.num1_str = None
26+
self.num2_str = None
27+
self.num_int = None
28+
self.num1_line = None
29+
self.num2_line = None
30+
31+
def is_valid(self):
32+
if self.train_type is None:
33+
print(f'ERROR: train_type has not been set yet, num: {self.num_int}')
34+
return False
35+
if self.num_int < 0: # TODO: think about setting an upper limit maybe?
36+
print(f'ERROR: not a valid number: {self.num_int}')
37+
self.num_int = None
38+
return False
39+
return True
40+
41+
def set_train_type_mms(self, train_type, type_line):
42+
if train_type not in TrainType:
43+
raise Exception(f'ERROR: train_type is not a valid value: {train_type}')
44+
if self.train_type is not None:
45+
raise Exception('ERROR: train_type was already set')
46+
self.train_type = train_type
47+
self.type_line = type_line
48+
49+
def set_train_type_text(self, train_type):
50+
if train_type not in TrainType:
51+
raise Exception(f'ERROR: train_type is not a valid value: {train_type}')
52+
if self.train_type is not None:
53+
raise Exception('ERROR: train_type was already set')
54+
self.train_type = train_type
55+
56+
def set_num_mms(self, num_str, line_num):
57+
if self.num2_str is not None or self.num_has_dashes:
58+
raise Exception('ERROR: number was already set')
59+
if num_str == '':
60+
raise Exception('ERROR: number is empty')
61+
62+
if '-' in num_str:
63+
num_str = num_str.replace('-', '')
64+
self.num_has_dashes = True
65+
try:
66+
num_int = int(num_str)
67+
except ValueError:
68+
raise Exception(f'ERROR: number is not a number: {num_str}')
69+
if num_int < 0:
70+
raise Exception(f'ERROR: number is negative: {num_int}')
71+
72+
if self.num1_str is None:
73+
if not self.num_has_dashes and num_int not in single_sign_numbers:
74+
print(f'WARNING: num1 is not a valid single sign value: {num_str}')
75+
self.num1_str = num_str
76+
self.num1_line = line_num
77+
else:
78+
if num_str not in ['20', '30', '40', '50', '60', '70', '80', '90']:
79+
print(f'WARNING: num2 is not a valid value: {num_str}')
80+
if self.num_int not in range(1, 10):
81+
print(f'WARNING: num1 is not a valid value: {self.num1_str}')
82+
self.num2_str = num_str
83+
self.num2_line = line_num
84+
85+
if self.num_int is None:
86+
self.num_int = 0
87+
self.num_int += num_int
88+
89+
def set_num_text(self, num_str):
90+
if self.num1_str is not None:
91+
raise Exception('ERROR: number was already set')
92+
if num_str == '':
93+
raise Exception('ERROR: number is empty')
94+
self.num1_str = num_str
95+
try:
96+
num_int = int(num_str)
97+
except ValueError:
98+
raise Exception(f'ERROR: number is not a number: {num_str}')
99+
100+
self.num_int = num_int
101+
102+
def get_line_numbers(self):
103+
if self.type_line is None:
104+
raise Exception('ERROR: type has not been set yet')
105+
if self.num1_line is None:
106+
raise Exception('ERROR: number has not been set yet')
107+
return (self.type_line, self.num1_line, self.num2_line)
108+
109+
def get_num_int(self):
110+
if self.num_int is None:
111+
raise Exception('ERROR: number has not been set yet')
112+
return self.num_int
113+
114+
def get_train_type(self):
115+
if self.train_type is None:
116+
raise Exception(f'ERROR: train_type has not been set yet, num: {self.num_int}')
117+
return self.train_type
118+
119+
120+
@verify(UNIQUE)
121+
class TrainType(Enum):
122+
# Key is in the format of the text dataset
123+
# Value is in the format of the mms dataset
124+
UNKNOWN = "UNKNOWN"
125+
RE = "R-E"
126+
RB = "R-B"
127+
IC = "I-C"
128+
ICE = "ICE"
129+
130+
def pick_random_train_type(exclude=set()):
131+
exclude.add(TrainType.UNKNOWN)
132+
all_train_types = set(TrainType)
133+
filtered_train_types = all_train_types.difference(exclude)
134+
return random.choice(tuple(filtered_train_types))
135+
136+
def mms_train_num_to_str(num1_int, num2_int):
137+
num1_str = str(num1_int)
138+
if num1_int > 100: # Add dashes between the digits
139+
assert(num2_int is None)
140+
result = ''
141+
for digit in num1_str:
142+
if len(result) != 0:
143+
result += '-'
144+
result += digit
145+
return result
146+
return num1_str
147+
148+
def mms_train_type_to_str(train_type):
149+
assert train_type != TrainType.UNKNOWN
150+
if train_type == TrainType.ICE:
151+
return train_type.value
152+
return f'fa:{train_type.value}'
153+
154+
def assemble_train_text(train_type, num_int):
155+
return f'{train_type.name} {num_int}'
156+
157+
158+
def replace_train_entities(dataset_text, dataset_mms):
159+
@verify(UNIQUE)
160+
class State(Enum):
161+
NOT_FOUND = 0
162+
TRAIN = 1
163+
FIRST_NUM = 2
164+
SECOND_NUM = 3
165+
TRAIN_FOUND = 4
166+
DONE_PARSING = 5
167+
168+
def process_train(state, train_data, train_positions):
169+
if state == State.TRAIN_FOUND:
170+
state = State.DONE_PARSING
171+
if not train_data.is_valid():
172+
return (state, train_positions)
173+
# Here we know that the train number is valid
174+
(type_line, num1_line, num2_line) = train_data.get_line_numbers()
175+
num_int = train_data.get_num_int()
176+
train_type = train_data.get_train_type()
177+
print(f'Found train number: {train_type.value} {num_int}, lines: {num1_line}, {num2_line} (file_number: {file_number})')
178+
train_position = (type_line, train_type, num1_line, num2_line, num_int)
179+
train_positions_list = train_positions.get(file_number, [])
180+
train_positions_list.append(train_position)
181+
train_positions[file_number] = train_positions_list
182+
return (state, train_positions)
183+
184+
result_mms = {}
185+
train_mappings = defaultdict(dict)
186+
train_counts = defaultdict(lambda: defaultdict(int))
187+
for file_number, file_contents in dataset_mms.items():
188+
state = State.NOT_FOUND
189+
train_data = TrainData()
190+
train_positions = {}
191+
word_after_train = None
192+
for line_num, row in enumerate(file_contents):
193+
word = row['maingloss']
194+
if state == State.DONE_PARSING:
195+
state = State.NOT_FOUND
196+
if word_after_train is None:
197+
word_after_train = word
198+
if has_numbers(word_after_train):
199+
print(f'WARNING: found number in {word_after_train} after the train (file_number: {file_number})')
200+
word_after_train = None
201+
if state == State.NOT_FOUND:
202+
train_data = TrainData()
203+
train_detected = False
204+
if word in ['fa:R-E', 'fa:R-B', 'fa:I-C', 'ICE']:
205+
train_type = TrainType(word.removeprefix('fa:'))
206+
train_data.set_train_type_mms(train_type, line_num)
207+
state = State.TRAIN
208+
train_detected = True
209+
else:
210+
state = State.NOT_FOUND
211+
word_no_dashes = word.replace('-', '')
212+
regex_matches = re.match(r'.*(?:RE|RB|IC|ICE)$', word_no_dashes) is not None
213+
should_be_detected = False
214+
if regex_matches:
215+
should_be_detected = True
216+
if word.isupper() and len(word) >= 4:
217+
should_be_detected = False
218+
if train_detected != should_be_detected:
219+
if train_detected:
220+
print(f"WARNING: train type detection wrong: got {word} in file {file_number}")
221+
else:
222+
print(f"WARNING: train type detection incomplete: got {word} in file {file_number}")
223+
elif state == State.TRAIN:
224+
if word.startswith('num:'):
225+
state = State.FIRST_NUM
226+
train_data.set_num_mms(word.removeprefix('num:'), line_num)
227+
else:
228+
print(f'WARNING: expected num:, got {word} in file {file_number}')
229+
state = State.NOT_FOUND
230+
elif state == State.FIRST_NUM:
231+
if word.startswith('num:'):
232+
state = State.TRAIN_FOUND
233+
train_data.set_num_mms(word.removeprefix('num:'), line_num)
234+
else:
235+
state = State.TRAIN_FOUND
236+
word_after_train = word
237+
238+
(state, train_positions) = process_train(state, train_data, train_positions)
239+
if state in [State.TRAIN, State.FIRST_NUM, State.SECOND_NUM]:
240+
state = State.TRAIN_FOUND
241+
# Do it a second time after the loop in case the train is at the end of the file
242+
(state, train_positions) = process_train(state, train_data, train_positions)
243+
244+
245+
246+
new_mms_data = []
247+
for row in dataset_mms[file_number]:
248+
new_row = row.copy()
249+
new_mms_data.append(new_row)
250+
251+
for file_number, train_infos in train_positions.items():
252+
for train_info in train_infos:
253+
(type_line, old_type, num1_line, num2_line, old_num) = train_info
254+
old_train = (old_type, old_num)
255+
if old_train not in train_mappings[file_number]:
256+
new_num1 = None
257+
new_num2 = None
258+
# TODO: change the train type if number of lines starts the same
259+
assert type_line is not None
260+
if num2_line is not None:
261+
new_type = pick_random_train_type(exclude={TrainType.ICE})
262+
else:
263+
new_type = pick_random_train_type()
264+
if new_type == TrainType.ICE:
265+
assert num2_line is None
266+
new_num1 = random.randrange(1000, 100000)
267+
else:
268+
if num1_line is not None and num2_line is None:
269+
new_num1 = random.choice(tuple(single_sign_numbers))
270+
elif num1_line is not None and num2_line is not None:
271+
new_num1 = random.randrange(1, 10)
272+
if num2_line is not None:
273+
new_num2 = random.choice([20, 30, 40, 50, 60, 70, 80, 90])
274+
new_train = (new_type, new_num1, new_num2)
275+
train_mappings[file_number][old_train] = new_train
276+
277+
train_counts[file_number][old_train] += 1
278+
new_train = train_mappings[file_number][old_train]
279+
280+
(new_type, new_num1, new_num2) = new_train
281+
new_type_str = mms_train_type_to_str(new_type)
282+
new_mms_data[type_line]['maingloss'] = new_type_str
283+
284+
if new_num1 is not None:
285+
new_num1 = mms_train_num_to_str(new_num1, new_num2)
286+
new_mms_data[num1_line]['maingloss'] = f'num:{new_num1}'
287+
if new_num2 is not None:
288+
if num2_line < len(new_mms_data):
289+
new_mms_data[num2_line]['maingloss'] = f'num:{new_num2}'
290+
else:
291+
print(f'WARNING: num2_line {num2_line} not found in file {file_number}')
292+
for line_num, row in enumerate(new_mms_data):
293+
print(f'line {line_num}: {row["maingloss"]}')
294+
295+
result_mms[file_number] = new_mms_data
296+
297+
print("train_mappings are", train_mappings)
298+
299+
300+
result_text = {}
301+
for folder_name, file_contents in dataset_text.items():
302+
train_name_pattern = r'\b((RE|RB|IC|ICE)\s*((?:\d-*){1,10}))\b'
303+
train_per_file = []
304+
all_distinct_trains_in_file = set()
305+
for (start_time, end_time, sentence, _number) in file_contents:
306+
matches = re.findall(train_name_pattern, sentence)
307+
for (whole_train, train_type, num) in matches:
308+
train_data = TrainData()
309+
train_data.set_num_text(num)
310+
train_type = TrainType[train_type]
311+
train_data.set_train_type_text(train_type)
312+
if not train_data.is_valid():
313+
continue
314+
num_int = train_data.get_num_int()
315+
assembled_train = assemble_train_text(train_type, num_int)
316+
assert assembled_train == whole_train, f'ERROR: assembled_train is {assembled_train}, whole_train is {whole_train}'
317+
all_distinct_trains_in_file.add((train_type, whole_train, num_int))
318+
319+
whole_train_mapping = {}
320+
train_mapping_str_to_tuple = {}
321+
for (old_type, old_whole_train, old_num) in all_distinct_trains_in_file:
322+
train_mapping = train_mappings[folder_name]
323+
if (old_type, old_num) not in train_mapping:
324+
print(f'WARNING: old_type {old_type} old_num {old_num} not found in train_mapping (file: {folder_name})')
325+
continue
326+
(new_type, new_num1, new_num2) = train_mapping[(old_type, old_num)]
327+
if new_num2 is not None:
328+
new_num1 += new_num2
329+
new_whole_train = assemble_train_text(new_type, new_num1)
330+
whole_train_mapping[old_whole_train] = new_whole_train
331+
train_mapping_str_to_tuple[old_whole_train] = (old_type, old_num)
332+
333+
334+
new_text_data = []
335+
replaced_counts = defaultdict(int)
336+
for (start_time, end_time, sentence, number) in file_contents:
337+
sentence, replaced_counts_per_line = replace_multiple(sentence, whole_train_mapping)
338+
for old_whole_train, counts in replaced_counts_per_line.items():
339+
old_whole_train_tuple = train_mapping_str_to_tuple[old_whole_train]
340+
replaced_counts[old_whole_train_tuple] += counts
341+
new_text_data.append((start_time, end_time, sentence, number))
342+
result_text[folder_name] = new_text_data
343+
344+
for train, count in train_counts[folder_name].items():
345+
replaced_count = replaced_counts[train]
346+
if replaced_count != count:
347+
print(f'WARNING: replaced_count in file {folder_name} should be {count} but was {replaced_count}, trying to replace {train}')
348+
349+
return (result_text, result_mms)
350+
351+
352+
353+
if __name__ == "__main__":
354+
from .dataset import *
355+
dataset_text = read_dataset_text()
356+
dataset_mms = read_dataset_mms()
357+
358+
result_text, result_mms = replace_train_entities(dataset_text, dataset_mms)
359+
write_dataset_text(result_text, main_folder = 'modified/train_name/text')
360+
write_dataset_mms(result_mms, main_folder = 'modified/train_name/mms')

0 commit comments

Comments
 (0)