Commit 1b5ca3d

yl: check compatibility lb
1 parent 08fc4a8 commit 1b5ca3d

File tree

6 files changed: +286 -2 lines changed


.gitignore

Lines changed: 2 additions & 1 deletion
@@ -158,4 +158,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
-notebooks/test*
+notebooks/test*
+experiments/LongBench/pred*

experiments/LongBench/README.md

Lines changed: 22 additions & 1 deletion
@@ -1 +1,22 @@
-# Minimal version of LongBench
+# Minimal version of LongBench
+This is a VERY minimal implementation (with slight modifications) of LongBench.
+
+You should be able to reproduce the results shown in our report.
+
+Hardware: 1 * A100-80GB (2 for `Mixtral`)
+
+Example:
+
+```
+bash scripts/run_longbench.sh
+```
+
+# Citation
+```
+@article{bai2023longbench,
+  title={LongBench: A Bilingual, Multitask Benchmark for Long Context Understanding},
+  author={Bai, Yushi and Lv, Xin and Zhang, Jiajie and Lyu, Hongchang and Tang, Jiankai and Huang, Zhidian and Du, Zhengxiao and Liu, Xiao and Zeng, Aohan and Hou, Lei and Dong, Yuxiao and Tang, Jie and Li, Juanzi},
+  journal={arXiv preprint arXiv:2308.14508},
+  year={2023}
+}
+```
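
For orientation, here is a minimal sketch (not part of the commit) of the per-line JSONL record that `eval.py` below expects in each prediction file. The `pred/<model>/` directory, the model name, and the `narrativeqa.jsonl` file name are illustrative assumptions, not taken from the commit:

```
import json
import os

# Illustrative names only: the model directory and dataset file are assumptions.
model = "mistral-7B-instruct-v0.2"
os.makedirs(f"pred/{model}", exist_ok=True)

record = {
    "pred": "The story takes place in London.",  # model output
    "answers": ["London"],                       # reference answers
    "all_classes": None,                         # label set, used only by classification tasks
    "length": 3200,                              # context length, used for the LongBench-E buckets
}
with open(f"pred/{model}/narrativeqa.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record, ensure_ascii=False) + "\n")
```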

experiments/LongBench/eval.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
import os
import json
import argparse
import numpy as np

from metrics import (
    qa_f1_score,
    rouge_zh_score,
    qa_f1_zh_score,
    rouge_score,
    classification_score,
    retrieval_score,
    retrieval_zh_score,
    count_score,
    code_sim_score,
)

# Map each LongBench dataset to its scoring metric.
dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "multifieldqa_zh": qa_f1_zh_score,
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "dureader": rouge_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
}

def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default=None)
    parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
    return parser.parse_args(args)

def scorer_e(dataset, predictions, answers, lengths, all_classes):
    # LongBench-E scoring: bucket examples by context length and average per bucket.
    scores = {"0-4k": [], "4-8k": [], "8k+": []}
    for (prediction, ground_truths, length) in zip(predictions, answers, lengths):
        score = 0.
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
        if length < 4000:
            scores["0-4k"].append(score)
        elif length < 8000:
            scores["4-8k"].append(score)
        else:
            scores["8k+"].append(score)
    for key in scores.keys():
        scores[key] = round(100 * np.mean(scores[key]), 2)
    return scores

def scorer(dataset, predictions, answers, all_classes):
    # Standard LongBench scoring: best score over references, averaged over examples.
    total_score = 0.
    for (prediction, ground_truths) in zip(predictions, answers):
        score = 0.
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
        total_score += score
    return round(100 * total_score / len(predictions), 2)

if __name__ == '__main__':
    args = parse_args()
    scores = dict()
    if args.e:
        path = f"pred_e/{args.model}/"
    else:
        path = f"pred/{args.model}/"
    all_files = os.listdir(path)
    print("Evaluating on:", all_files)
    for filename in all_files:
        if not filename.endswith("jsonl"):
            continue
        predictions, answers, lengths = [], [], []
        dataset = filename.split('.')[0]
        with open(f"{path}{filename}", "r", encoding="utf-8") as f:
            for line in f:
                data = json.loads(line)
                predictions.append(data["pred"])
                answers.append(data["answers"])
                all_classes = data["all_classes"]
                if "length" in data:
                    lengths.append(data["length"])
        if args.e:
            score = scorer_e(dataset, predictions, answers, lengths, all_classes)
        else:
            score = scorer(dataset, predictions, answers, all_classes)
        if dataset == 'qasper':
            score_e = scorer_e(dataset, predictions, answers, lengths, all_classes)
        scores[dataset] = score
        # if dataset == 'qasper':
        #     scores[dataset + '_e'] = score_e
    # Both modes currently write the aggregated scores to the same result file.
    if args.e:
        out_path = f"H2O/results/{args.model}/result.json"
    else:
        out_path = f"H2O/results/{args.model}/result.json"
    # out_path_e = f"pred/{args.model}/result_e.json"
    # with open(out_path_e, "w") as f:
    #     json.dump(score_e, f, ensure_ascii=False, indent=4)
    with open(out_path, "w") as f:
        json.dump(scores, f, ensure_ascii=False, indent=4)
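
For orientation, a hedged usage sketch of the `scorer` and `scorer_e` helpers above on toy inputs. The inputs are illustrative, not results from this commit, and it assumes `eval.py` and its `metrics` dependencies are importable from the working directory:

```
from eval import scorer, scorer_e

predictions = ["Paris is the capital of France.", "The answer is 42.", "Berlin"]
answers = [["Paris"], ["42"], ["Munich"]]   # one list of reference answers per prediction
lengths = [3500, 6000, 9000]                # context lengths in tokens

# Overall score for a QA-F1 dataset: best score over references per example,
# averaged over examples and reported in percent.
print(scorer("hotpotqa", predictions, answers, all_classes=None))

# LongBench-E variant: the same metric, bucketed by context length.
print(scorer_e("hotpotqa", predictions, answers, lengths, all_classes=None))
```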

experiments/LongBench/metrics.py

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
import re
import string

import jieba
from fuzzywuzzy import fuzz
import difflib

from typing import List
from collections import Counter
from rouge import Rouge

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def normalize_zh_answer(s):
    """Lower text and remove punctuation, extra whitespace."""

    def white_space_fix(text):
        return "".join(text.split())

    def remove_punc(text):
        cn_punctuation = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
        all_punctuation = set(string.punctuation + cn_punctuation)
        return "".join(ch for ch in text if ch not in all_punctuation)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_punc(lower(s)))

def count_score(prediction, ground_truth, **kwargs):
    # Fraction of numbers in the prediction that equal the reference count.
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)

def retrieval_score(prediction, ground_truth, **kwargs):
    # Fraction of numbers in the prediction that match the gold paragraph id.
    pattern = r'Paragraph (\d+)'
    matches = re.findall(pattern, ground_truth)
    ground_truth_id = matches[0]
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth_id):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)

def retrieval_zh_score(prediction, ground_truth, **kwargs):
    # Chinese variant of retrieval_score ("段落N" instead of "Paragraph N").
    pattern = r'段落(\d+)'
    matches = re.findall(pattern, ground_truth)
    ground_truth_id = matches[0]
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth_id):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)

def code_sim_score(prediction, ground_truth, **kwargs):
    # Fuzzy string similarity on the first non-comment line of the prediction.
    all_lines = prediction.lstrip('\n').split('\n')
    prediction = ""
    for line in all_lines:
        if ('`' not in line) and ('#' not in line) and ('//' not in line):
            prediction = line
            break
    return (fuzz.ratio(prediction, ground_truth) / 100)

def classification_score(prediction, ground_truth, **kwargs):
    # Substring match against the label set; credit is split across ambiguous matches.
    em_match_list = []
    all_classes = kwargs["all_classes"]
    for class_name in all_classes:
        if class_name in prediction:
            em_match_list.append(class_name)
    for match_term in em_match_list:
        if match_term in ground_truth and match_term != ground_truth:
            em_match_list.remove(match_term)
    if ground_truth in em_match_list:
        score = (1.0 / len(em_match_list))
    else:
        score = 0.0
    return score

def rouge_score(prediction, ground_truth, **kwargs):
    # ROUGE-L F1 via the `rouge` package; returns 0 if scoring fails.
    rouge = Rouge()
    try:
        scores = rouge.get_scores([prediction], [ground_truth], avg=True)
    except:
        return 0.0
    return scores["rouge-l"]["f"]

def rouge_zh_score(prediction, ground_truth, **kwargs):
    # Segment Chinese text with jieba before computing ROUGE-L.
    prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
    ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
    score = rouge_score(prediction, ground_truth)
    return score

def f1_score(prediction, ground_truth, **kwargs):
    # Token-level F1 between two token lists.
    common = Counter(prediction) & Counter(ground_truth)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction)
    recall = 1.0 * num_same / len(ground_truth)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

def qa_f1_score(prediction, ground_truth, **kwargs):
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    return f1_score(prediction_tokens, ground_truth_tokens)


def qa_f1_zh_score(prediction, ground_truth, **kwargs):
    prediction_tokens = list(jieba.cut(prediction, cut_all=False))
    ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
    prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
    ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
    prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
    ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
    return f1_score(prediction_tokens, ground_truth_tokens)
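
Likewise, a small illustrative sketch of a few of the metric helpers above on toy strings (assuming the third-party packages imported at the top of `metrics.py` are installed; none of these inputs come from the commit):

```
from metrics import qa_f1_score, classification_score, count_score

# Token-level F1 after lowercasing and stripping punctuation and articles.
print(qa_f1_score("The capital is Paris.", "Paris"))

# Substring match against a label set; credit is split if several labels appear.
print(classification_score("I think it is sports news", "sports",
                           all_classes=["sports", "politics", "tech"]))

# Fraction of the numbers in the prediction that equal the reference count.
print(count_score("There are 7 paragraphs, maybe 8.", "7"))
```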
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
# example
CUDA_VISIBLE_DEVICES=6 python pred_snap.py --model mistral-7B-instruct-v0.2 --compress_args_path ablation_c4096_w32_k7_maxpool.json

scripts/run_longbench.sh

Whitespace-only changes.
