Commit a987e7d

Enable scoring in inference engine py (#2543)
* added scoring method for InferenceEnginePY
* added wikitext-2 benchmark
* use a fixed window size of 512 tokens in the wikitext-2 ppl benchmark
1 parent e37d5a5 commit a987e7d
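
For context, a minimal sketch of how the benchmark turns per-window scores into a corpus perplexity. It mirrors compute_file_ppl in the script below and assumes each scored result is a (summed log-likelihood, token count) pair for one 512-token window; the numbers are illustrative only, not from an actual run.

import numpy as np

# Hypothetical per-window results: (summed log-likelihood, token count).
scored_results = [(-1200.5, 511), (-1175.3, 511), (-1190.8, 511)]

lls = [ll for ll, _ in scored_results]
lengths = [n for _, n in scored_results]

# Corpus-level perplexity: exp of the negative total log-likelihood
# divided by the total number of scored tokens.
file_ppl = np.exp(-np.sum(lls) / np.sum(lengths))
print("ppl: %.4f" % file_ppl)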

File tree

8 files changed: 45095 additions & 0 deletions

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
import copy
import json
import numpy as np
import os
import pyonmttok
import time
from onmt.constants import CorpusTask, DefaultTokens
from onmt.inference_engine import InferenceEnginePY
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
import onmt.opts as opts
from onmt.utils.logging import init_logger
from onmt.utils.parse import ArgumentParser
from onmt.utils.misc import use_gpu, set_random_seed
from onmt.transforms import get_transforms_cls


def compute_file_ppl(output_filename):
    with open(output_filename, "r") as f:
        run_results = json.load(f)
    nlls = []
    lengths = []
    for i, _res in enumerate(run_results["scored_results"]):
        print(_res)
        nlls.append(_res[0])
        lengths.append(_res[1])
    file_ppl = np.exp(-np.sum(nlls) / np.sum(lengths))
    print("wikitext-2 ppl: %.4f" % file_ppl)


def tokenize_dataset(opt, context_length):
    print("Tokenization...")

    # Prepare the dataset
    x = open(opt.src, "r").readlines()
    x = [_x.rstrip("\n") for _x in x]
    y = DefaultTokens.SEP.join(x)

    with open(opt.src + ".temp", "w") as writer:
        writer.write(y)

    # ########################## #
    # Build the dataset iterator #
    # ########################## #

    # Build the vocab
    vocab_path_in = "/nas-labs/LM/big_llms/llama/7B/llama.vocab"
    voc = []
    with open(vocab_path_in, "r", encoding="utf-8") as reader:
        for line in reader:
            line = line.strip("\n")
            voc.append(line)
    vocabs = {}
    src_vocab = pyonmttok.build_vocab_from_tokens(voc)
    vocabs["src"] = src_vocab
    vocabs["tgt"] = src_vocab
    vocabs["data_task"] = "lm"
    vocabs["decoder_start_token"] = "<s>"

    transforms_cls = get_transforms_cls(opt._all_transform)

    new_opt = opt
    new_opt.gpu = -1
    new_opt.parallel_mode = "data_parallel"
    new_opt.src = opt.src + ".temp"

    dataset_iter = build_dynamic_dataset_iter(
        new_opt, transforms_cls, vocabs, task=CorpusTask.INFER, device_id=-1
    )

    input_tokens = []
    for batch, i in dataset_iter:
        for i in range(batch["src"].size()[0]):
            start_ids = batch["src"][i, :, 0].cpu().numpy().tolist()
            input_tokens += [
                vocabs["src"].lookup_index(id)
                for id in start_ids
                if id != vocabs["src"].lookup_token(DefaultTokens.PAD)
            ]

    def make_chunks(lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i : i + n]

    # #################### #
    # Tokenize the dataset #
    # #################### #
    with open(opt.src + f".tokenized.context_{context_length}", "w") as writer:
        for _chunk in make_chunks(input_tokens, context_length - 1):
            writer.write(" ".join(_chunk) + "\n")
            print(len(_chunk))

    print("Done !")

    z = open(opt.src + f".tokenized.context_{context_length}", "r").readlines()
    print(len(z[0].split(" ")))


def evaluate(opt):
    """Score the wikitext2 testset"""
    ArgumentParser.validate_translate_opts(opt)
    ArgumentParser._get_all_transform_translate(opt)
    ArgumentParser._validate_transforms_opts(opt)
    ArgumentParser.validate_translate_opts_dynamic(opt)
    logger = init_logger(opt.log_file)
    set_random_seed(opt.seed, use_gpu(opt))

    run_results = {}
    dir_name = os.path.dirname(opt.models[0])
    base_name = os.path.basename(opt.models[0])

    output_filename = os.path.join(
        dir_name, "wikitext-2_benchmark_%s.json" % base_name[:-3]
    )

    # Build the translator (along with the model).
    engine_opt = copy.copy(opt)
    engine_opt._all_transform = []
    engine = InferenceEnginePY(engine_opt)

    # Tokenize the dataset.
    opt.src = "wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw"
    tokenize_dataset(opt, context_length=512)

    # Score the tokenized dataset.
    engine.opt.src = opt.src + f".tokenized.context_{512}"
    start_time = time.time()
    scored_results = engine.score_file()
    engine.terminate()
    run_results["scored_results"] = scored_results

    with open(output_filename, "w") as f:
        json.dump(run_results, f, ensure_ascii=False, indent=2)

    compute_file_ppl(output_filename)

    end_time = time.time()
    logger.info("total run time %.2f" % (end_time - start_time))


def _get_parser():
    parser = ArgumentParser(description="run_wikitext-2_benchmark.py")
    opts.config_opts(parser)
    opts.translate_opts(parser, dynamic=True)
    return parser


def main():
    parser = _get_parser()
    opt = parser.parse_args()
    evaluate(opt)


if __name__ == "__main__":
    main()
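
A note on the fixed 512-token window from the commit message: the script writes chunks of context_length - 1 tokens, presumably so that once the "<s>" decoder start token is prepended at scoring time each scored sequence is exactly 512 tokens long; that reading is an inference, not stated in the commit. A small illustrative sketch of the chunking (the token stream here is made up):

def make_chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]

context_length = 512
tokens = ["tok%d" % i for i in range(1200)]  # illustrative token stream
windows = list(make_chunks(tokens, context_length - 1))
print([len(w) for w in windows])  # -> [511, 511, 178]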

eval_llm/WIKITEXT2/wikitext-2-raw-v1/wikitext-2-raw/wiki.test.raw

Lines changed: 4358 additions & 0 deletions
Large diffs are not rendered by default.
