forked from IBM/mt-rag-benchmark
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconversations2retrieval.py
More file actions
110 lines (94 loc) · 4.26 KB
/
conversations2retrieval.py
File metadata and controls
110 lines (94 loc) · 4.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
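# Convert a workbench conversations file into the BEIR retrieval format.
# For each retriever collection, files are written under
# <output_dir>/retrieval/<collection>/:
#   queries_turns<turns_to_keep>_qonly<q_only>.jsonl  (one JSON object per line), e.g.
#     {"_id": "0", "text": "What is considered a business expense on a business trip?"}
#   (BEIR queries may also carry a "metadata": {} field; this script does not write one.)
#   qrels/dev.tsv  (tab-separated, with a header row), e.g.
#     query-id  corpus-id  score
#     8         566392     1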
import json
import os
import pandas as pd
import argparse
def read_json(filename: str, encoding: str = "utf-8"):
    """Load and return the JSON document stored at *filename*.

    Args:
        filename: Path of the JSON file to read.
        encoding: Text encoding used to decode the file.

    Returns:
        The decoded JSON value (whatever ``json.load`` produces).
    """
    with open(filename, encoding=encoding) as handle:
        document = json.load(handle)
    return document
def write_json(filename: str, content: list | dict, encoding: str = "utf-8"):
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, mode="w", encoding=encoding) as fp:
return json.dump(content, fp)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--input",
type=str,
required=True,
dest="input",
help="Path containing dataset to run",
)
parser.add_argument(
"-o",
"--output",
type=str,
required=True,
dest="output_dir",
help="Path containing dataset to run",
)
parser.add_argument(
"-t",
"--turns_to_keep",
type=int,
default=-1,
dest="turns_to_keep",
help="Which turns to keep. It includs user + agent so: -1 is the last turn, -3 is the current question + previous q+a, 0 is full conversation",
)
parser.add_argument(
"-q",
"--q_only",
action="store_true",
dest="q_only",
help="Only use the questions and not the responses",
)
return parser.parse_args()
def _is_relevant(context: dict, wb_conversation: dict) -> bool:
    """Return True if *context* was judged relevant ("yes") for this conversation.

    The editor's judgement wins when the conversation has an ``editor`` who
    judged this context; otherwise the author's judgement is used. A context
    without a ``feedback`` entry is never relevant.
    """
    if "feedback" not in context:
        return False
    relevant = context["feedback"]["relevant"]
    if "editor" in wb_conversation and wb_conversation["editor"] in relevant:
        judge = wb_conversation["editor"]
    else:
        judge = wb_conversation["author"]
    return relevant[judge]["value"] == "yes"


def convert_conversations(
    wb_conversations: list[dict], turns_to_keep: int = -1, q_only: bool = False
) -> tuple[dict, dict]:
    """Build BEIR queries and qrels, grouped by retriever collection name.

    Every user message becomes a query whose text is the tracked turns
    ``conversation[turns_to_keep:]`` joined by newlines. Each turn is
    rendered as ``|<speaker>|: <text>``. With ``q_only`` only user messages
    are tracked.

    A non-user message answers the most recent user query of the same
    conversation. Its relevant contexts (see ``_is_relevant``) become qrels
    for that query. If it has none, the query is dropped as unanswerable; the
    turn itself stays in the history of later queries. A non-user message
    with no pending user query (e.g. one that opens a conversation, or a
    second reply in a row) contributes only to the history. It never deletes
    or annotates another query.

    Args:
        wb_conversations: Workbench conversation dicts.
        turns_to_keep: Slice start over the tracked turns (-1 = current
            question only, 0 = full conversation).
        q_only: Track only user messages.

    Returns:
        ``(queries, qrels)``: dicts mapping collection name to a list of
        ``{"_id", "text"}`` and ``{"query-id", "corpus-id", "score"}`` rows.
    """
    queries: dict[str, list[dict]] = {}
    qrels: dict[str, list[dict]] = {}
    for wb_conversation in wb_conversations:
        collection_name = wb_conversation["retriever"]["collection"]["name"]
        collection_queries = queries.setdefault(collection_name, [])
        collection_qrels = qrels.setdefault(collection_name, [])
        conversation = []
        # Id of the user query awaiting an answer; it is always the last row
        # appended to collection_queries while set.
        pending_id = None
        for message in wb_conversation["messages"]:
            is_user = message["speaker"] == "user"
            # Track the turn so it can be part of later queries.
            if is_user or not q_only:
                conversation.append(f"|{message['speaker']}|: {message['text']}")
            if is_user:
                pending_id = f"{wb_conversation['author']}_{message['timestamp']}"
                collection_queries.append(
                    {"_id": pending_id, "text": "\n".join(conversation[turns_to_keep:])}
                )
                continue
            if pending_id is None:
                # No user query to attach this response to.
                continue
            answerable = False
            for context in message["contexts"]:
                if _is_relevant(context, wb_conversation):
                    collection_qrels.append(
                        {"query-id": pending_id, "corpus-id": context["document_id"], "score": 1}
                    )
                    answerable = True
            if not answerable:
                # Safe: the pending query is the last one appended.
                del collection_queries[-1]
            pending_id = None
    return queries, qrels


if __name__ == "__main__":
    # Step 1: read command line arguments
    args = parse_args()
    # Steps 2-5: build per-collection queries and qrels, skipping unanswerables
    queries, qrels = convert_conversations(
        read_json(args.input), turns_to_keep=args.turns_to_keep, q_only=args.q_only
    )
    # Step 6: save files in BEIR retrieval format
    for collection_name in queries:
        collection_dir = f"{args.output_dir}/retrieval/{collection_name}"
        os.makedirs(f"{collection_dir}/qrels/", exist_ok=True)
        # Explicit columns keep the header even for a collection with no rows.
        pd.DataFrame(queries[collection_name], columns=["_id", "text"]).to_json(
            f"{collection_dir}/queries_turns{args.turns_to_keep}_qonly{args.q_only}.jsonl",
            lines=True,
            orient="records",
        )
        pd.DataFrame(
            qrels[collection_name], columns=["query-id", "corpus-id", "score"]
        ).to_csv(f"{collection_dir}/qrels/dev.tsv", index=False, sep="\t")