forked from EuroEval/EuroEval
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_norglm_multiqa.py
More file actions
247 lines (199 loc) · 7.99 KB
/
create_norglm_multiqa.py
File metadata and controls
247 lines (199 loc) · 7.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
# /// script
# requires-python = ">=3.10,<4.0"
# dependencies = [
# "datasets==2.15.0",
# "huggingface-hub==0.24.0",
# "openai==1.66.5",
# "pandas==2.2.0",
# "python-dotenv==1.0.1",
# "requests==2.32.3",
# ]
# ///
"""Create the NorGLM NO-multi question answering dataset."""
import ast
import hashlib
import os
import pandas as pd
from datasets import Dataset, DatasetDict, Split, load_dataset
from dotenv import load_dotenv
from huggingface_hub import HfApi
from openai import OpenAI
from .constants import (
MAX_NUM_CHARS_IN_CONTEXT,
MAX_NUM_CHARS_IN_QUESTION,
MIN_NUM_CHARS_IN_CONTEXT,
MIN_NUM_CHARS_IN_QUESTION,
)
load_dotenv()
def main() -> None:
    """Create the NorGLM NO-multi question answering dataset and upload to HF Hub.

    Pipeline:
        1. Load the raw `NorGLM/NO-Multi-QA-Sum` dataset from the Hugging Face Hub.
        2. Clean it: drop a junk index column and one non-article row,
           deduplicate contexts, and explode the per-article question/answer
           lists into one row per QA pair.
        3. Keep only samples whose context and question lengths fall within the
           configured bounds.
        4. Ensure every answer occurs verbatim in its context, using OpenAI
           (gpt-4o, fixed seed, temperature 0) to rephrase the answers that do
           not, and discarding rephrasings that still cannot be found in the
           context.
        5. Build a SQuAD-style `answers` column, split into train/val/test, add
           deterministic MD5 IDs, and push to `EuroEval/norglm-multi-qa`.

    Raises:
        KeyError:
            If the `OPENAI_API_KEY` environment variable is not set.
    """
    dataset_id = "NorGLM/NO-Multi-QA-Sum"

    # OpenAI client used to rephrase answers that do not appear in their context
    client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])

    dataset = load_dataset(dataset_id, split="train", token=True)
    assert isinstance(dataset, Dataset)

    # The raw dataset names the context column `article`
    dataset = dataset.rename_columns(column_mapping=dict(article="context"))

    df = dataset.to_pandas()
    assert isinstance(df, pd.DataFrame)

    # Drop unneeded index column
    df.drop("Unnamed: 0", inplace=True, axis=1)

    # Drop non-article by index
    df.drop(index=359, inplace=True)

    # Shuffle and drop duplicate contexts; the seed fixes which duplicate
    # survives, keeping the result reproducible
    df = df.sample(frac=1, random_state=4242).drop_duplicates(subset="context")

    # Reset the index
    df = df.reset_index(drop=True)

    def qa_to_list(row: str) -> list:
        """Converts a string representation of a list of tuples to an actual list.

        Args:
            row:
                The row to convert.

        Returns:
            The row converted to an actual list of (question, answer) tuples,
            with surrounding whitespace stripped from both elements.
        """
        qa_list = ast.literal_eval(row)
        quest_ans = [(q.strip(), a.strip()) for q, a in qa_list]
        return quest_ans

    # Convert the question_answer column to a list of tuples
    df["question_answer"] = df["question_answer"].apply(qa_to_list)

    # Split the question_answer list [(question, answer), ...] column into
    # separate question and answer columns, repeating the context for each pair
    df = df.explode("question_answer")
    df[["question", "answer"]] = pd.DataFrame(
        df.question_answer.tolist(), index=df.index
    )
    df.drop(["question_answer", "summary"], inplace=True, axis=1)
    df.reset_index(drop=True, inplace=True)

    # Only work with samples where the context is not very large or small
    lengths = df.context.str.len()
    lower_bound = MIN_NUM_CHARS_IN_CONTEXT
    upper_bound = MAX_NUM_CHARS_IN_CONTEXT
    df = df[lengths.between(lower_bound, upper_bound)]

    # Only work with samples where the question is not very large or small
    lengths = df.question.str.len()
    lower_bound = MIN_NUM_CHARS_IN_QUESTION
    upper_bound = MAX_NUM_CHARS_IN_QUESTION
    df = df[lengths.between(lower_bound, upper_bound)]
    assert isinstance(df, pd.DataFrame)

    def rephrase_answer(question: str, answer: str, context: str) -> str:
        """Rephrase the answer such that it is in the context.

        Args:
            question:
                The question.
            answer:
                The answer.
            context:
                The context.

        Returns:
            The rephrased answer (if the answer is already in the context, it is
            returned as is).
        """
        # Strip a single trailing full stop so that sentence-final answers can
        # match a phrase inside the context
        answer = answer[:-1] if answer.endswith(".") else answer
        if answer.lower() in context.lower():
            return answer

        # Use OpenAI to rephrase the answer; fixed seed and temperature 0 aim
        # for reproducible outputs
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": (
                        f"Given the following context: '{context}', and the "
                        f"'{question}' please rephrase the answer, so that it matches "
                        "exactly with a phrase from the context in a case-insensitive "
                        "way. E.g. it must be possible to find the rephrased answer "
                        "in the context without any modifications. The rephrased "
                        "answer should be as concise as possible, preferable not "
                        "more than 7 words, and ideally 3 or less words. Now "
                        f"rephrase this answer: '{answer}'"
                    ),
                }
            ],
            model="gpt-4o",
            seed=4242,
            temperature=0,
        )
        rephrased_answer = chat_completion.choices[0].message.content

        # Fall back to the original answer if the model returned no content
        return rephrased_answer or answer

    # Use the original answer if it could be found in the context
    df_orig_bool = df[["answer", "context"]].apply(
        lambda row: row["answer"].lower() in row["context"].lower(), axis=1
    )
    df_orig = df[df_orig_bool]

    # For the rest of the answers, we try to rephrase them. Copy the slice so
    # that the `.loc` assignments below operate on an independent frame rather
    # than a view of `df` (avoids SettingWithCopyWarning and is required for
    # the writes to stick under pandas copy-on-write semantics)
    df_no_context = df[~df_orig_bool].copy()

    # Rephrase the answers
    df_no_context.loc[:, "answer"] = df_no_context.apply(
        lambda row: rephrase_answer(row["question"], row["answer"], row["context"]),
        axis=1,
    )
    assert isinstance(df_no_context, pd.DataFrame)

    # Remove non-word start and end characters from answers
    df_no_context.loc[:, "answer"] = df_no_context["answer"].str.replace(
        r"(^\W|\W$)", "", regex=True
    )

    # Only keep the rephrased answers where the answer could be found in the context
    df_in_context = df_no_context[["answer", "context"]].apply(
        lambda row: row["answer"].lower() in row["context"].lower(), axis=1
    )
    df_with_answer = df_no_context[df_in_context]

    # Combine original with the rephrased answers
    cleaned_df = pd.concat([df_orig, df_with_answer])
    cleaned_df = cleaned_df.reset_index(drop=True)

    # Convert to SQuAD-style 'answers' column. NOTE(review): `answer_start` is
    # stored as a scalar int rather than a list parallel to `text` — confirm
    # this matches the format the downstream consumer expects
    cleaned_df.loc[:, "answers"] = cleaned_df.apply(
        lambda row: {
            "text": [row["answer"]],
            "answer_start": row["context"].lower().index(row["answer"].lower()),
        },
        axis=1,
    )

    # Overwrite the original dataframe with the cleaned version
    df = cleaned_df

    # Sanity check: should not change since we use a fixed seed and temperature
    # 0 for rephrasing (NOTE(review): OpenAI does not guarantee bitwise
    # determinism even with a seed, so this count may need updating on re-runs)
    assert len(df) == 2406

    # Create validation split
    val_size = 256
    val_df = df.sample(n=val_size, random_state=4242)

    # Create train split
    train_size = 1024
    filtered_df = df[~df.index.isin(val_df.index)]
    assert isinstance(filtered_df, pd.DataFrame)
    train_df = filtered_df.sample(n=train_size, random_state=4242)

    # Create test split, using the remaining samples
    test_df = filtered_df[~filtered_df.index.isin(train_df.index)]

    val_df = val_df.reset_index(drop=True)
    test_df = test_df.reset_index(drop=True)
    train_df = train_df.reset_index(drop=True)

    # Add ID column, deterministically derived from context + question so the
    # same sample always gets the same ID
    train_df["id"] = [
        hashlib.md5((row.context + row.question).encode("utf-8")).hexdigest()
        for _, row in train_df.iterrows()
    ]
    val_df["id"] = [
        hashlib.md5((row.context + row.question).encode("utf-8")).hexdigest()
        for _, row in val_df.iterrows()
    ]
    test_df["id"] = [
        hashlib.md5((row.context + row.question).encode("utf-8")).hexdigest()
        for _, row in test_df.iterrows()
    ]
    assert isinstance(train_df, pd.DataFrame)
    assert isinstance(val_df, pd.DataFrame)
    assert isinstance(test_df, pd.DataFrame)

    # Check that the IDs are unique
    assert train_df.id.nunique() == len(train_df)
    assert val_df.id.nunique() == len(val_df)
    assert test_df.id.nunique() == len(test_df)

    # Collect datasets in a dataset dictionary
    dataset = DatasetDict(
        {
            "train": Dataset.from_pandas(train_df, split=Split.TRAIN),
            "val": Dataset.from_pandas(val_df, split=Split.VALIDATION),
            "test": Dataset.from_pandas(test_df, split=Split.TEST),
        }
    )

    # Push the dataset to the Hugging Face Hub, deleting any existing repo
    # first so stale files from previous uploads do not linger
    dataset_id = "EuroEval/norglm-multi-qa"
    HfApi().delete_repo(dataset_id, repo_type="dataset", missing_ok=True)
    dataset.push_to_hub(dataset_id, private=True)
# Run the dataset-creation pipeline when executed as a script
if __name__ == "__main__":
    main()