-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_scratch.py
More file actions
119 lines (107 loc) · 3.81 KB
/
eval_scratch.py
File metadata and controls
119 lines (107 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
from annotatedtransformer.utils import *
from annotatedtransformer.vocab import Tokenizer
from annotatedtransformer.transformer import make_model
from torchmetrics.text import BLEUScore
def check_outputs(
    test_dataloader,
    model,
    vocab_src,
    vocab_tgt,
    n_examples=15,
    pad_idx=2,
):
    """Print source / ground-truth / model-output triples for a few test batches.

    Greedy-decodes up to ``n_examples`` batches (batch_size is expected to be 1:
    only row 0 of each batch is inspected) and prints a per-example BLEU score.

    Args:
        test_dataloader: iterable yielding (src_batch, tgt_batch) tensor pairs.
        model: trained transformer; caller is expected to have called .eval().
        vocab_src / vocab_tgt: tokenizers exposing ``get_itos()`` (index->token).
        n_examples: number of batches to display before stopping.
        pad_idx: padding token id, filtered out of all printed sequences.

    Returns:
        List of (batch, src_tokens, tgt_tokens, out_tokens) tuples, one per example.
    """
    results = []
    bleu = BLEUScore()
    # Hoist the index->token tables out of the loop (get_itos() per token is wasteful).
    src_itos = vocab_src.get_itos()
    tgt_itos = vocab_tgt.get_itos()
    for idx, b in enumerate(test_dataloader):
        print("\nExample %d ========\n" % idx)
        rb = Batch(b[0], b[1], pad_idx)
        src_tokens = [src_itos[x] for x in rb.src[0] if x != pad_idx]
        tgt_tokens = [tgt_itos[x] for x in rb.tgt[0] if x != pad_idx]
        print(
            "Source Text (Input)        : "
            + " ".join(src_tokens).replace("\n", "")
        )
        print(
            "Target Text (Ground Truth) : "
            + " ".join(tgt_tokens).replace("\n", "")
        )
        # Inference only — no need to build an autograd graph while decoding.
        with torch.no_grad():
            model_out = greedy_decode(model, rb.src, rb.src_mask, 100, 0)[0]
        out_tokens = [tgt_itos[x] for x in model_out if x != pad_idx]
        # Truncate at end-of-sentence if present. Guarding matches calc_bleu();
        # a bare .index("</s>") raised ValueError when decoding hit the length
        # limit without ever emitting the EOS token.
        if "</s>" in out_tokens:
            out_tokens = out_tokens[: out_tokens.index("</s>")]
        print(
            "Model Output               : "
            + " ".join(out_tokens).replace("\n", "")
        )
        preds = " ".join(out_tokens)
        targets = " ".join(tgt_tokens)
        print(
            "BLEU Score : {:0.4f}".format(
                bleu([preds], [[targets]])
            )
        )
        results.append((rb, src_tokens, tgt_tokens, out_tokens))
        if idx == n_examples - 1:
            break
    return results
def calc_bleu(
    device,
    test_dataloader,
    model,
    vocab_tgt,
    pad_idx=2,
):
    """Compute a per-sentence BLEU score over the whole test dataloader.

    Args:
        device: torch.device to run decoding on.
        test_dataloader: iterable yielding (src_batch, tgt_batch) tensor pairs.
        model: trained transformer (moved to ``device`` and set to eval here).
        vocab_tgt: target tokenizer exposing ``get_itos()``.
        pad_idx: padding token id, filtered out before scoring.

    Returns:
        List of 0-dim BLEU score tensors, one per test sentence.
    """
    from tqdm import tqdm
    model.to(device)
    model.eval()
    itos = vocab_tgt.get_itos()
    bleu_scores = []
    bleu = BLEUScore()
    # Pure inference: disabling autograd avoids building a graph for every
    # greedy-decode step, cutting memory use and speeding up the loop.
    with torch.no_grad():
        for b in tqdm(test_dataloader):
            rb = Batch(b[0].to(device), b[1].to(device), pad_idx)
            # 72 = max decode length; 0 = start-symbol id for greedy_decode.
            outs = greedy_decode(model, rb.src, rb.src_mask, 72, 0)
            batch_size = rb.src.size(0)
            for i in range(batch_size):
                tgt_tokens = [itos[x] for x in rb.tgt[i] if x != pad_idx]
                out_tokens = [itos[x] for x in outs[i] if x != pad_idx]
                # Truncate at EOS if the model emitted one before the length limit.
                if "</s>" in out_tokens:
                    out_tokens = out_tokens[: out_tokens.index("</s>")]
                preds = " ".join(out_tokens)
                targets = " ".join(tgt_tokens)
                bleu_scores.append(bleu([preds], [[targets]]))
    return bleu_scores
def main():
    """Evaluate a trained de->en transformer: show sample decodes, then test BLEU."""
    from train_scratch import create_dataloaders_Multi30k

    src_lang, tgt_lang = "de", "en"
    pair = (src_lang, tgt_lang)

    # BPE tokenizers trained with an 8000-symbol vocabulary per language.
    vocab_src = Tokenizer(f"output/vocab/de-en/bpe/{src_lang}_8000.model")
    vocab_tgt = Tokenizer(f"output/vocab/de-en/bpe/{tgt_lang}_8000.model")

    print("Preparing Data ...")
    _, _, test_dataloader = create_dataloaders_Multi30k(
        vocab_src, vocab_tgt,
        batch_size=1, language_pair=pair,
    )

    print("Loading Trained Model ...")
    ckpt_path = f"output/model/{src_lang}2{tgt_lang}/transformer_model.pt"
    model = make_model(len(vocab_src), len(vocab_tgt), N=6)
    state = torch.load(ckpt_path, map_location=torch.device("cpu"))
    model.load_state_dict(state)
    model.eval()

    print("Checking Model Outputs:")
    # Qualitative check: print a handful of decoded examples (batch_size=1 loader).
    check_outputs(
        test_dataloader, model, vocab_src, vocab_tgt, n_examples=10
    )

    print("\nCalculating BLEU Score on Testset ...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Rebuild the test loader with a larger batch for the full BLEU pass.
    _, _, test_dataloader = create_dataloaders_Multi30k(
        vocab_src, vocab_tgt,
        batch_size=64, language_pair=pair,
    )
    scores = calc_bleu(device, test_dataloader, model, vocab_tgt)
    print("BLEU Score on testset: {:0.4f}".format(torch.tensor(scores).mean().item()))
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()