Skip to content

Commit f912837

Browse files
Dataset generation and filtering, and import from other datasets (#21)
* Add typo generation * Make FAQ title optional * Generate typos in FAQs * Add entry variants * Refactor FaqConfig & scale typos per word * Add per-word multiplier * Clean-up * Filter short questions * Only do 2 epochs * Fix fine-tuning & change parameters * Use AnglELoss * Rename "generate" & implement Wiki QA
1 parent f7e4445 commit f912837

File tree

3 files changed

+278
-102
lines changed

3 files changed

+278
-102
lines changed

bingus-python-encoder/data_utils.py

Lines changed: 215 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,25 @@
1+
import math
12
import os
3+
from typing import TypeAlias
24
from pydantic import BaseModel
3-
from datasets import Dataset
5+
from datasets import Dataset, load_dataset
6+
from typo import StrErrer
7+
from random import Random
48

9+
RandomSeed: TypeAlias = int | float | str | bytes | bytearray | None
510

6-
class FaqEntry(BaseModel):
7-
title: str
8-
answer: str
9-
matched_questions: list[str]
1011

11-
12-
class FaqConfig(BaseModel):
13-
faqs: list[FaqEntry]
14-
15-
16-
def load_faq_config(paths: list[str]) -> FaqConfig:
17-
"""
18-
Searches through a list of paths to find and load the first existing faq_config.json file.
19-
Raises a FileNotFoundError if none of the paths exist.
20-
"""
21-
for path in paths:
22-
if os.path.isfile(path):
23-
print(f"Found \"faq_config.json\" at \"{path}\"!")
24-
with open(path, "r") as f:
25-
return FaqConfig.model_validate_json(f.read())
26-
raise FileNotFoundError(
27-
"Could not find \"faq_config.json\" in any of the default paths.")
12+
def split_dataset(dataset: Dataset, eval_percent: float | int) -> tuple[Dataset, Dataset | None]:
    """
    Splits the dataset into training and evaluation sets based on the
    evaluation percentage. When eval_percent is zero or negative, the full
    dataset is returned as the training set and the evaluation set is None.
    """
    # Guard clause: nothing to hold out.
    if eval_percent <= 0:
        return dataset, None
    parts = dataset.train_test_split(test_size=eval_percent)
    return parts["train"], parts["test"]
2818

2919

30-
def generate_entry_pairs(entries: list[list[str]]) -> Dataset:
20+
def make_entry_pairs(entries: list[list[str]]) -> Dataset:
3121
"""
32-
Generates item-to-item pairs from the entry list, where each item is paired with all
22+
Makes item-to-item pairs from the entry list, where each item is paired with all
3323
other item in its set (positive samples) and from other sets (negative sample).
3424
"""
3525
items1, items2, scores = [], [], []
@@ -56,69 +46,217 @@ def generate_entry_pairs(entries: list[list[str]]) -> Dataset:
5646
})
5747

5848

59-
def generate_question_pairs(faqs: list[FaqEntry]) -> Dataset:
60-
"""
61-
Generates question-to-question pairs from the FAQs, where each question is paired with all
62-
other questions in its set (positive samples) and from other sets (negative sample).
63-
"""
64-
return generate_entry_pairs([faq.matched_questions for faq in faqs])
49+
def random_typo(str_err: StrErrer, random: Random) -> StrErrer:
    """Applies one randomly chosen typo mutation to the wrapped string."""
    # The eight mutation operations, indexed by the same uniform 0-7 draw
    # the original if/elif ladder used, so RNG consumption is unchanged.
    mutations = (
        StrErrer.char_swap,
        StrErrer.missing_char,
        StrErrer.extra_char,
        StrErrer.nearby_char,
        StrErrer.skipped_space,
        StrErrer.random_space,
        StrErrer.repeated_char,
        StrErrer.unichar,
    )
    return mutations[random.randint(0, 7)](str_err)
6567

6668

67-
def generate_question_answer_pairs(faqs: list[FaqEntry], include_title: bool = True) -> Dataset:
68-
"""
69-
Generates question-answer pairs from the FAQs, where each question is paired with its correct
70-
answer (positive sample) and other incorrect answers (negative samples).
71-
"""
69+
class FaqEntry(BaseModel):
    # One FAQ entry: an answer plus the question phrasings that map to it.
    # Optional display title; may be None (made optional in this commit).
    title: str | None
    # The canonical answer text for this entry.
    answer: str
    # All question phrasings that should resolve to this entry's answer.
    matched_questions: list[str]
73+
74+
75+
class FaqConfig(BaseModel):
    """Top-level FAQ configuration: the collection of all FAQ entries."""
    faqs: list[FaqEntry]

    @staticmethod
    def load_from_file(paths: list[str] | str) -> "FaqConfig":
        """
        Searches through a list of paths (or a single path) to find and load
        the first existing faq_config.json file.
        Raises a FileNotFoundError if none of the paths exist.
        """
        # Fix: a bare string would otherwise be iterated character by
        # character, so no path could ever match.
        if isinstance(paths, str):
            paths = [paths]
        for path in paths:
            if os.path.isfile(path):
                print(f"Found \"faq_config.json\" at \"{path}\"!")
                with open(path, "r") as f:
                    return FaqConfig.model_validate_json(f.read())
        raise FileNotFoundError(
            "Could not find \"faq_config.json\" in any of the default paths.")

    def save_to_file(self, path: str):
        """
        Saves a faq_config.json file to the specified path.
        """
        with open(path, "w") as f:
            f.write(self.model_dump_json())

    def iterate_answers(self):
        """Yields every answer across all entries."""
        for faq in self.faqs:
            yield faq.answer

    def iterate_questions(self):
        """Yields every matched question across all entries."""
        for faq in self.faqs:
            for question in faq.matched_questions:
                yield question

    def question_count(self) -> int:
        """Returns the total number of matched questions across all entries."""
        return sum(len(faq.matched_questions) for faq in self.faqs)

    def filter_short_questions(self, min_words: int):
        """
        Filters out questions shorter than min_words words and removes
        entries left with no questions at all.
        """
        for faq in self.faqs:
            faq.matched_questions = [
                q for q in faq.matched_questions if len(q.split()) >= min_words]
        self.faqs = [faq for faq in self.faqs if len(
            faq.matched_questions) > 0]

    def make_typos(
        self,
        entry_variants: int,
        min_typos: int,
        max_typos: int,
        scale_max_per_word: bool = True,
        scale_min_per_word: bool = False,
        per_word_multiplier: float = 1.0,
        seed: RandomSeed = None
    ) -> tuple[int, int]:
        """
        Makes typos for each question of each entry and returns the number of
        entries added and the number of typos made.

        For every existing question, entry_variants typo'd copies are
        appended to the entry. Each copy receives a random number of typos
        drawn between min_typos and max_typos, optionally scaled by the
        question's word count times per_word_multiplier.

        Raises ValueError when the variant count or typo bounds are invalid.
        """
        if entry_variants < 1:
            raise ValueError(
                "entry_variants must be greater than or equal to 1")
        if min_typos < 0:
            raise ValueError("min_typos must be greater than or equal to 0")
        if max_typos < 1:
            raise ValueError("max_typos must be greater than or equal to 1")
        if min_typos > max_typos:
            raise ValueError(
                "min_typos must be less than or equal to max_typos")

        seeded_random = Random(seed)
        typo_entry_count = 0
        typo_count = 0
        for faq in self.faqs:
            new_qs: list[str] = []

            for question in faq.matched_questions:
                q_min_typos = min_typos
                q_max_typos = max_typos
                # Fix: compute the per-word scale unconditionally. It was
                # previously defined only inside the scale_max_per_word
                # branch, so scale_min_per_word alone raised a NameError.
                num_words = max(1, len(question.split())
                                * per_word_multiplier)
                if scale_max_per_word:
                    q_max_typos *= num_words
                if scale_min_per_word:
                    q_min_typos *= num_words

                for _ in range(entry_variants):
                    # Bounds may be fractional after scaling; ceil keeps the
                    # range valid for randint.
                    num_typos = seeded_random.randint(
                        math.ceil(q_min_typos), math.ceil(q_max_typos))
                    typo_q = StrErrer(question, seed=seeded_random.random())
                    for _ in range(num_typos):
                        typo_q = random_typo(typo_q, seeded_random)
                    new_qs.append(typo_q.result)
                    typo_count += num_typos

            faq.matched_questions.extend(new_qs)
            typo_entry_count += len(new_qs)

        return typo_entry_count, typo_count

    def make_question_pairs(self) -> Dataset:
        """
        Makes question-to-question pairs from the FAQs, where each question is paired with all
        other questions in its set (positive samples) and from other sets (negative sample).
        """
        return make_entry_pairs([faq.matched_questions for faq in self.faqs])

    def make_question_answer_pairs(self) -> Dataset:
        """
        Makes question-answer pairs from the FAQs, where each question is paired with its correct
        answer (positive sample) and other incorrect answers (negative samples).
        """
        questions, answers, scores = [], [], []

        for faq in self.faqs:
            for question in faq.matched_questions:
                # Positive sample (correct answer)
                questions.append(question)
                answers.append(faq.answer)
                scores.append(1.0)

                # Negative samples (incorrect answers)
                for other_answer in self.iterate_answers():
                    if other_answer != faq.answer:
                        questions.append(question)
                        answers.append(other_answer)
                        scores.append(0.0)

        return Dataset.from_dict({
            "sentence1": questions,
            "sentence2": answers,
            "score": scores,
        })

    def make_everything_pairs(self) -> Dataset:
        """
        Makes pairs of titles, answers, and questions from the FAQs, where each set is paired with its correct
        answer (positive sample) and other incorrect answers (negative samples).
        """
        entries = []
        for faq in self.faqs:
            items = [faq.answer, *faq.matched_questions]
            # Fix: title is optional (str | None); skip None so it never
            # leaks into the pair dataset as a sentence.
            if faq.title is not None:
                items.insert(0, faq.title)
            entries.append(items)
        return make_entry_pairs(entries)
216+
217+
218+
def make_wiki_qa_dataset(faqs: FaqConfig, max_count: int = -1) -> Dataset:
    """
    Builds a negative-sample dataset from Microsoft's WikiQA corpus: each
    distinct WikiQA question is paired against every FAQ answer, and every
    WikiQA candidate answer is paired against every FAQ question, all scored
    0.0. Collection stops once max_count pairs exist (max_count <= 0 means
    no limit).
    """
    questions, answers, scores = [], [], []

    def reached_limit() -> bool:
        # max_count <= 0 disables the limit entirely.
        return 0 < max_count <= len(questions)

    wiki_qa = load_dataset("microsoft/wiki_qa")
    previous_id = ""
    for row in wiki_qa["train"]:
        # WikiQA repeats each question once per candidate answer; pair the
        # question itself only the first time its id appears.
        current_id = row["question_id"]
        if previous_id != current_id:
            previous_id = current_id

            # Negatively pair question with FAQ answers
            wiki_question = row["question"]
            for faq_answer in faqs.iterate_answers():
                questions.append(wiki_question)
                answers.append(faq_answer)
                scores.append(0.0)

                if reached_limit():
                    break

            if reached_limit():
                break

        # Negatively pair answer with FAQ questions
        wiki_answer = row["answer"]
        for faq_question in faqs.iterate_questions():
            questions.append(faq_question)
            answers.append(wiki_answer)
            scores.append(0.0)

            if reached_limit():
                break

        if reached_limit():
            break

    return Dataset.from_dict({
        "sentence1": questions,
        "sentence2": answers,
        "score": scores,
    })
109-
110-
111-
def generate_everything_pairs(faqs: list[FaqEntry]) -> Dataset:
112-
"""
113-
Generates pairs of titles, answers, and questions from the FAQs, where each set is paired with its correct
114-
answer (positive sample) and other incorrect answers (negative samples).
115-
"""
116-
return generate_entry_pairs([[faq.title, faq.answer, *faq.matched_questions] for faq in faqs])
117-
118-
119-
def split_dataset(dataset: Dataset, eval_percent: float | int) -> tuple[Dataset, Dataset | None]:
120-
"""Splits the dataset into training and evaluation sets based on the evaluation percentage."""
121-
if eval_percent > 0:
122-
split = dataset.train_test_split(test_size=eval_percent)
123-
return split["train"], split["test"]
124-
return dataset, None

0 commit comments

Comments
 (0)