Commit be064aa

Merge pull request #2 from MaithriRao/sentence-piece-tokenization
Different data split strategy
2 parents 19ac968 + 4bd5328 commit be064aa
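The new strategy, shared by both scripts below: the parallel text/gloss lists are partitioned with a single scikit-learn hold-out split, 75% train / 25% test, with a fixed seed so training and inference see the same partition. A minimal sketch of just the split (source_sentences and target_glosses are placeholder names for the parallel lists):

    from sklearn.model_selection import train_test_split

    # 75/25 hold-out split with a fixed seed, as used in both scripts below
    source_train, source_test, target_train, target_test = train_test_split(
        source_sentences, target_glosses, test_size=0.25, random_state=42
    )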

7 files changed: 1110 additions, 0 deletions

data_split/1_fold/only_gloss.py

Lines changed: 208 additions & 0 deletions
import math
import pickle
import torchtext
import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab
from torch import Tensor
import io
import time
import os
import pandas as pd
import json
from pathlib import Path
from datetime import datetime
from torch.utils.data import DataLoader
from typing import List
from sacrebleu.metrics import BLEU
import numpy as np
from .. import datasets
from ..model import Model
from ..utils import Tokenization
from sklearn.model_selection import train_test_split

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 128

def train_and_evaluate(ds, tokenization, augment):
    if tokenization == Tokenization.SOURCE_ONLY:
        tokenization_dir = "source_only"
    elif tokenization == Tokenization.SOURCE_TARGET:
        tokenization_dir = "source_target"
    else:
        raise ValueError("Invalid tokenization value")

    if not augment:
        augment_dir = "original_data"
    else:
        augment_dir = "aug_data"

    #time_dir = str(datetime.now()).replace(" ", "__")

    save_folder = os.path.join("data_split/1_fold", tokenization_dir, augment_dir, "onlyGloss")
    save_file_path = os.path.join(save_folder, "result")
    Path(save_folder).mkdir(parents=True, exist_ok=True)

    model = Model(ds, augment)

    (original, modified, full) = ds
    (tokenizer_original, vocab_original, sentences_original) = original
    (tokenizer_full, vocab_full, sentences_full) = full
    (source_text_full, target_gloss_full) = sentences_full
    (source_text_original, target_gloss_original) = sentences_original

    # Hold out 25% of the data for evaluation: the augmented run splits the
    # full (augmented) corpus, the baseline run splits the original corpus.
    if augment:
        source_train, source_test, target_train, target_test = train_test_split(
            source_text_full, target_gloss_full, test_size=0.25, random_state=42)
    else:
        source_train, source_test, target_train, target_test = train_test_split(
            source_text_original, target_gloss_original, test_size=0.25, random_state=42)

    train_data = model.data_process(source_train, target_train, tokenization)
    # test_data = model.data_process(source_test, target_test, tokenization)

    train_iter = DataLoader(train_data, batch_size=BATCH_SIZE,
                            shuffle=True, collate_fn=model.generate_batch)
    # test_iter = DataLoader(test_data, batch_size=BATCH_SIZE,
    #                        shuffle=True, collate_fn=generate_batch)

    NUM_EPOCHS = 1000
    loss_graf = []

    transformer = model.create_transformer()
    transformer = transformer.to(device)

    optimizer = torch.optim.Adam(
        transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9
    )

    train_log = open(save_file_path + "_train_log.txt", 'w')

    best_epoch = 0

    for epoch in range(1, NUM_EPOCHS + 1):
        start_time = time.time()

        train_loss = model.train_epoch(transformer, train_iter, optimizer)
        if not augment:
            train_loss = train_loss.tolist()

        end_time = time.time()
        log = "Epoch: " + str(epoch) + ", Train loss: " + str(train_loss) + " Epoch duration " + str(end_time - start_time) + "\n"
        train_log.write(log)
        # loss_graf is empty at epoch 1, so only compare from epoch 2 onwards
        if epoch > 1 and train_loss < min(loss_graf):
            torch.save(transformer.state_dict(), save_file_path + "_best_model.pt")
            log = "min so far is at epoch: " + str(epoch) + "\n"
            train_log.write(log)
            best_epoch = epoch

        loss_graf.append(train_loss)

    log = "best epoch is: " + str(best_epoch)
    train_log.write(log)
    train_log.close()

    torch.save(transformer.state_dict(), save_file_path + "_last_model.pt")

    # Evaluation
    ground_truth = []
    hypothesis = []
    num_P_T = 0  # sentences where the prediction is longer than the target
    num_T_P = 0  # sentences where the target is longer than the prediction
    num_e = 0    # sentences where both have equal length

    for de_text, gl_text in zip(source_test, target_test):
        if tokenization == Tokenization.SOURCE_TARGET:
            source_entry = de_text[1]
            target_entry = gl_text[1]

            print(f"Source Sentence : {source_entry}")
            print(f"Target Sentence : {target_entry}")

            gl_pred = model.translate(transformer, source_entry, model.text_vocab, model.mms_vocab, model.text_tokenizer, tokenization)
            print(f"gloss prediction : {gl_pred}")

            # Rebuild words from SentencePiece pieces: "▁" marks a word
            # boundary, and the spaces between pieces are dropped.
            translated_sentence = ""
            for char in gl_pred:
                if char == "▁":
                    translated_sentence += " "
                elif char != " ":
                    translated_sentence += char

            translated_sentence = translated_sentence.strip()

            print(f"translated_sentence: {translated_sentence}")

            ground_truth.append(target_entry)
            hypothesis.append(translated_sentence)

            P = len(translated_sentence.split())
            T = len(target_entry.split())

        elif tokenization == Tokenization.SOURCE_ONLY:
            source_entry = de_text[1]
            target_entry = "".join(gl_text[1])

            print(f"Source Sentence: {source_entry}")
            print(f"Target Sentence: {target_entry}")

            gl_pred = model.translate(transformer, source_entry, model.text_vocab, model.mms_vocab, model.text_tokenizer, tokenization)
            print(f"Predicted gloss : {gl_pred}")

            ground_truth.append(target_entry)
            hypothesis.append(gl_pred)

            P = len(gl_pred.split())
            T = len(target_entry.split())

        else:
            raise ValueError("Invalid tokenization value")

        if P > T:
            print("P:", P)
            num_P_T += 1
        elif T > P:
            print("T:", T)
            num_T_P += 1
        else:
            num_e += 1

    with open(save_file_path + "_outputs.txt", "w") as f:
        f.write(f"P>T: {num_P_T}\n")
        f.write(f"T>P: {num_T_P}\n")
        f.write(f"equal: {num_e}\n")

        refs = [ground_truth]
        bleu = BLEU()
        result = bleu.corpus_score(hypothesis, refs)
        f.write(f"BLEU score for maingloss: {result}\n")

if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print("Usage: python only_gloss.py [--source-only|--source-target]")
        sys.exit(1)

    if sys.argv[1] == "--source-only":
        print("Using source only")
        tokenization = Tokenization.SOURCE_ONLY
    elif sys.argv[1] == "--source-target":
        print("Using source and target")
        tokenization = Tokenization.SOURCE_TARGET
    else:
        print("You have to specify either --source-only or --source-target as an argument.")
        sys.exit(1)

    ds = datasets.read(tokenization)
    print("Original data :")
    train_and_evaluate(ds, tokenization, augment=False)

    print("Augmented data:")
    train_and_evaluate(ds, tokenization, augment=True)
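
A note on the detokenization loop above: SentencePiece marks word boundaries with "▁" (U+2581), so the character loop rebuilds words by turning each "▁" into a space and dropping the spaces that separate the predicted pieces. An equivalent sketch using plain string methods, assuming gl_pred is the raw piece sequence:

    # drop inter-piece spaces, then map the "▁" boundary marker to a space
    translated_sentence = gl_pred.replace(" ", "").replace("▁", " ").strip()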
data_split/1_fold/only_gloss_inference_best.py

Lines changed: 186 additions & 0 deletions
import math
import pickle
import torchtext
import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab
from torch import Tensor
import io
import time
import os
import pandas as pd
import json
from pathlib import Path
from datetime import datetime
from torch.utils.data import DataLoader
from typing import List
from sacrebleu.metrics import BLEU
import numpy as np
from .. import datasets
from ..model import Model
from ..utils import Tokenization
from sklearn.model_selection import train_test_split

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 128

def best_model(ds, tokenization, augment):
    if tokenization == Tokenization.SOURCE_ONLY:
        tokenization_dir = "source_only"
    elif tokenization == Tokenization.SOURCE_TARGET:
        tokenization_dir = "source_target"
    else:
        raise ValueError("Invalid tokenization value")

    model = Model(ds, augment)

    (original, modified, full) = ds
    (tokenizer_original, vocab_original, sentences_original) = original
    (tokenizer_full, vocab_full, sentences_full) = full
    (source_text_full, target_gloss_full) = sentences_full
    (source_text_original, target_gloss_original) = sentences_original

    # Recreate the same 75/25 hold-out split (same seed as training) so the
    # test set matches the one the best model was selected on.
    if augment:
        source_train, source_test, target_train, target_test = train_test_split(
            source_text_full, target_gloss_full, test_size=0.25, random_state=42)
    else:
        source_train, source_test, target_train, target_test = train_test_split(
            source_text_original, target_gloss_original, test_size=0.25, random_state=42)

    transformer = model.create_transformer()
    transformer = transformer.to(device)

    augment_or_original_dir = "aug_data" if augment else "original_data"

    save_folder_path = os.path.join("data_split", "1_fold", tokenization_dir, augment_or_original_dir, "onlyGloss")
    model_file_path = os.path.join(save_folder_path, "result_best_model.pt")

    # Load the checkpoint with the lowest training loss saved by only_gloss.py
    transformer.load_state_dict(torch.load(model_file_path))

    ground_truth = []
    hypothesis = []
    preds_file = open(save_folder_path + "_predictions.txt", "w")

    num_P_T = 0
    num_T_P = 0
    num_e = 0

    for de_text, gl_text in zip(source_test, target_test):
        if tokenization == Tokenization.SOURCE_TARGET:
            source_entry = de_text[1]
            target_entry = gl_text[1]

            print(f"Source Sentence : {source_entry}")
            print(f"Target Sentence : {target_entry}")

            gl_pred = model.translate(transformer, source_entry, model.text_vocab, model.mms_vocab, model.text_tokenizer, tokenization)
            print(f"gloss prediction : {gl_pred}")

            # Rebuild words from SentencePiece pieces: "▁" marks a word
            # boundary, and the spaces between pieces are dropped.
            translated_sentence = ""
            for char in gl_pred:
                if char == "▁":
                    translated_sentence += " "
                elif char != " ":
                    translated_sentence += char

            translated_sentence = translated_sentence.strip()

            print(f"translated_sentence: {translated_sentence}")

            ground_truth.append(target_entry)
            hypothesis.append(translated_sentence)

            P = len(translated_sentence.split())
            T = len(target_entry.split())

        elif tokenization == Tokenization.SOURCE_ONLY:
            source_entry = de_text[1]
            target_entry = "".join(gl_text[1])

            print(f"Source Sentence: {source_entry}")
            print(f"Target Sentence: {target_entry}")

            gl_pred = model.translate(transformer, source_entry, model.text_vocab, model.mms_vocab, model.text_tokenizer, tokenization)
            print(f"Predicted gloss : {gl_pred}")

            ground_truth.append(target_entry)
            hypothesis.append(gl_pred)

            P = len(gl_pred.split())
            T = len(target_entry.split())

        else:
            raise ValueError("Invalid tokenization value")

        if P > T:
            print("P:", P)
            num_P_T += 1
        elif T > P:
            print("T:", T)
            num_T_P += 1
        else:
            num_e += 1

        preds_file.write(str(de_text[0]) + "\n")
        preds_file.write(de_text[1] + "\n")
        preds_file.write(target_entry + "\n")
        preds_file.write(gl_pred + "\n")
        preds_file.write("************************************\n")

    preds_file.close()

    f = open(save_folder_path + "_outputs.txt", "w")

    f.write("P>T: " + str(num_P_T) + "\n")
    f.write("T>P: " + str(num_T_P) + "\n")
    f.write("equal: " + str(num_e) + "\n")

    # use the lists ground_truth, hypothesis
    refs = [ground_truth]

    bleu = BLEU()
    result = bleu.corpus_score(hypothesis, refs)

    f.write("BLEU score for maingloss: " + str(result) + "\n")

    f.close()

if __name__ == "__main__":
    import sys

    if len(sys.argv) != 2:
        print("Usage: python only_gloss_inference_best.py [--source-only|--source-target]")
        sys.exit(1)

    if sys.argv[1] == "--source-only":
        print("Using source only")
        tokenization = Tokenization.SOURCE_ONLY
    elif sys.argv[1] == "--source-target":
        print("Using source and target")
        tokenization = Tokenization.SOURCE_TARGET
    else:
        print("You have to specify either --source-only or --source-target as an argument.")
        sys.exit(1)

    ds = datasets.read(tokenization)
    print("Original data :")
    best_model(ds, tokenization, augment=False)

    print("Augmented data:")
    best_model(ds, tokenization, augment=True)
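
For reference, sacrebleu's corpus_score takes the full list of hypothesis strings plus a list of reference streams, one stream per reference set, which is why both scripts wrap the references as refs = [ground_truth]. A minimal self-contained check with toy glosses:

    from sacrebleu.metrics import BLEU

    hypothesis = ["HAUS GROSS", "HEUTE REGEN"]    # toy predicted glosses
    ground_truth = ["HAUS GROSS", "HEUTE SONNE"]  # toy reference glosses

    bleu = BLEU()
    result = bleu.corpus_score(hypothesis, [ground_truth])  # one reference stream
    print(result)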
