-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsearch.py
More file actions
86 lines (67 loc) · 3.31 KB
/
search.py
File metadata and controls
86 lines (67 loc) · 3.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import pickle
import json
from metric_eval import evaluate_page, evaluate_layout
from tqdm import tqdm
import argparse
def batch_dot_product(query_vec, passage_vecs):
    """Multiply the passage array by the query vector and return the product.

    Evaluates ``passage_vecs @ query_vec`` and hands the result back as-is.

    NOTE(review): presumably ``passage_vecs`` is (num_passages, dim) and
    ``query_vec`` is (dim,), yielding one score per passage -- confirm
    against the encoder output.
    """
    scores = passage_vecs @ query_vec
    return scores
def load_pickle(file_in):
    """Open ``file_in`` in binary mode and return the object unpickled from it.

    NOTE: ``pickle.load`` can execute arbitrary code while deserializing, so
    this must only be pointed at files from a trusted source.
    """
    with open(file_in, mode="rb") as handle:
        payload = pickle.load(handle)
    return payload
def initialize_args(argv=None):
    """Parse the command-line options of this script.

    Example: search.py BGE --encode query,page,layout

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing no-argument callers behave exactly as before.

    Returns:
        argparse.Namespace with the attributes ``model``, ``encode_path``
        and ``encode``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('model', type=str, help='Model name, e.g. BGE')
    parser.add_argument(
        '--encode_path', type=str, default='encode',
        help='Directory containing the pickled encodings',
    )
    parser.add_argument(
        '--encode', type=str, default="query,page,layout",
        help='Comma-separated targets to process, e.g. query,page,layout',
    )
    return parser.parse_args(argv)
def _load_ground_truth(path):
    """Flatten the JSONL annotation file at ``path`` into one list of questions.

    Each non-blank line is parsed as a JSON object; every entry of its
    ``"questions"`` list gets the object's ``"domain"`` copied onto it and is
    appended to the returned list, preserving file order.
    """
    questions = []
    # Context manager so the file handle is closed deterministically.
    with open(path, 'r', encoding="utf-8") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file)
                # instead of crashing inside json.loads.
                continue
            item = json.loads(line)
            for qa in item["questions"]:
                qa["domain"] = item["domain"]
                questions.append(qa)
    return questions


def main():
    """Score encoded pages/layouts against each encoded query and report recall.

    Reads the pickled encodings for the chosen model, attaches per-query
    score lists to the ground-truth question dicts, then calls
    ``evaluate_page`` / ``evaluate_layout`` at several top-k cut-offs.

    Raises:
        ValueError: if the number of indexed queries differs from the number
            of ground-truth questions.
    """
    # ["BGE", "E5", "GTE", "Contriever", "DPR", "ColBERT"]
    args = initialize_args()
    model, encode, encode_path = args.model, args.encode, args.encode_path

    # Loop-invariant switches, evaluated once instead of on every query.
    is_colbert = model.startswith("Col")
    do_page = "page" in encode
    do_layout = "layout" in encode

    if is_colbert:
        # Only imported when the model name starts with "Col".
        from metric_eval import colbert_score, pad_tok_len

    encoded_query, query_indices = load_pickle(f"{encode_path}/encoded_query_{model}.pkl")
    if do_page:
        # The second element of the page pickle is not used by this script.
        encoded_page, _ = load_pickle(f"{encode_path}/encoded_page_{model}.pkl")
    if do_layout:
        encoded_layout, layout_indices = load_pickle(f"{encode_path}/encoded_layout_{model}.pkl")

    gt_list = _load_ground_truth("dataset/MMDocIR_annotations.jsonl")
    if len(gt_list) != len(query_indices):
        raise ValueError("number of indexed question do not match ground-truth")

    # Each index tuple carries inclusive start/end offsets into the page and
    # layout encodings, hence the ``+ 1`` on every slice below.
    for query_id, start_pid, end_pid, start_lid, end_lid in tqdm(query_indices):
        query_vec = encoded_query[query_id]
        record = gt_list[query_id]

        if do_page:
            page_vecs = encoded_page[start_pid:end_pid + 1]
            if is_colbert:
                page_vecs_pad, masks_page = pad_tok_len(page_vecs)
                scores_page = colbert_score(query_vec, page_vecs_pad, masks_page)
            else:
                scores_page = batch_dot_product(query_vec, page_vecs)
            record["scores_page"] = scores_page.tolist()

        if do_layout:
            layout_vecs = encoded_layout[start_lid:end_lid + 1]
            if is_colbert:
                layout_vecs_pad, masks_layout = pad_tok_len(layout_vecs)
                # NOTE(review): only the layout call passes use_gpu=True; the
                # page call relies on colbert_score's default -- confirm this
                # asymmetry is intentional.
                scores_layout = colbert_score(query_vec, layout_vecs_pad, masks_layout, use_gpu=True)
            else:
                scores_layout = batch_dot_product(query_vec, layout_vecs)
            record["scores_layout"] = scores_layout.tolist()
            record["layout_indices"] = layout_indices[start_lid:end_lid + 1]

    if do_page:
        for topk in (1, 3, 5):
            evaluate_page(gt_list, model_name=model, topk=topk, metric="recall")
    if do_layout:
        for topk in (1, 5, 10):
            evaluate_layout(gt_list, model_name=model, topk=topk, metric="recall")


if __name__ == "__main__":
    main()