1+ import math
12import os
3+ from typing import TypeAlias
24from pydantic import BaseModel
3- from datasets import Dataset
5+ from datasets import Dataset , load_dataset
6+ from typo import StrErrer
7+ from random import Random
48
9+ RandomSeed : TypeAlias = int | float | str | bytes | bytearray | None
510
6- class FaqEntry (BaseModel ):
7- title : str
8- answer : str
9- matched_questions : list [str ]
1011
11-
12- class FaqConfig (BaseModel ):
13- faqs : list [FaqEntry ]
14-
15-
16- def load_faq_config (paths : list [str ]) -> FaqConfig :
17- """
18- Searches through a list of paths to find and load the first existing faq_config.json file.
19- Raises a FileNotFoundError if none of the paths exist.
20- """
21- for path in paths :
22- if os .path .isfile (path ):
23- print (f"Found \" faq_config.json\" at \" { path } \" !" )
24- with open (path , "r" ) as f :
25- return FaqConfig .model_validate_json (f .read ())
26- raise FileNotFoundError (
27- "Could not find \" faq_config.json\" in any of the default paths." )
def split_dataset(dataset: Dataset, eval_percent: float | int) -> tuple[Dataset, Dataset | None]:
    """
    Divide ``dataset`` into a training part and an evaluation part.

    ``eval_percent`` is forwarded unchanged as the ``test_size`` argument of
    ``dataset.train_test_split``. When it is not greater than zero no split is
    performed: the dataset is returned as-is and the evaluation part is None.
    """
    # Guard clause: nothing requested for evaluation, keep everything for training.
    if not (eval_percent > 0):
        return dataset, None

    parts = dataset.train_test_split(test_size=eval_percent)
    train_part = parts["train"]
    eval_part = parts["test"]
    return train_part, eval_part
2818
2919
30- def generate_entry_pairs (entries : list [list [str ]]) -> Dataset :
20+ def make_entry_pairs (entries : list [list [str ]]) -> Dataset :
3121 """
32- Generates item-to-item pairs from the entry list, where each item is paired with all
22+ Makes item-to-item pairs from the entry list, where each item is paired with all
3323 other item in its set (positive samples) and from other sets (negative sample).
3424 """
3525 items1 , items2 , scores = [], [], []
@@ -56,69 +46,217 @@ def generate_entry_pairs(entries: list[list[str]]) -> Dataset:
5646 })
5747
5848
59- def generate_question_pairs (faqs : list [FaqEntry ]) -> Dataset :
60- """
61- Generates question-to-question pairs from the FAQs, where each question is paired with all
62- other questions in its set (positive samples) and from other sets (negative sample).
63- """
64- return generate_entry_pairs ([faq .matched_questions for faq in faqs ])
def random_typo(str_err: StrErrer, random: Random) -> StrErrer:
    """
    Apply one randomly selected typo operation to ``str_err``.

    A single integer is drawn with ``random.randint(0, 7)``; it selects which
    zero-argument method of ``str_err`` is called. The return value of that
    method is returned unchanged.
    """
    # Draw value -> method name. Any draw not listed here (i.e. 7) falls back
    # to ``unichar``.
    methods_by_draw = {
        0: "char_swap",
        1: "missing_char",
        2: "extra_char",
        3: "nearby_char",
        4: "skipped_space",
        5: "random_space",
        6: "repeated_char",
    }
    draw = random.randint(0, 7)
    method_name = methods_by_draw.get(draw, "unichar")
    return getattr(str_err, method_name)()
6567
6668
67- def generate_question_answer_pairs (faqs : list [FaqEntry ], include_title : bool = True ) -> Dataset :
68- """
69- Generates question-answer pairs from the FAQs, where each question is paired with its correct
70- answer (positive sample) and other incorrect answers (negative samples).
71- """
class FaqEntry(BaseModel):
    # One FAQ record: an answer plus the question phrasings that map to it.

    # Heading of the entry; may be None. No default value is declared.
    title: str | None
    # Answer text of the entry.
    answer: str
    # Question phrasings associated with this answer.
    matched_questions: list[str]
73+
74+
class FaqConfig(BaseModel):
    """
    A collection of FAQ entries with helpers to load/save it as JSON, augment
    its questions with typos, and build sentence-pair datasets from it.
    """

    faqs: list[FaqEntry]

    @staticmethod
    def load_from_file(paths: list[str] | str) -> "FaqConfig":
        """
        Loads the first existing faq_config.json file among the given paths.

        ``paths`` is either a single path or a list of candidate paths that
        are tried in order. Raises a FileNotFoundError if none of them is an
        existing file.
        """
        # A bare string is one path; iterating it directly would walk its
        # characters and never find the file.
        candidates = [paths] if isinstance(paths, str) else paths
        for path in candidates:
            if os.path.isfile(path):
                print(f"Found \"faq_config.json\" at \"{path}\"!")
                # JSON is read as UTF-8 regardless of the platform default.
                with open(path, "r", encoding="utf-8") as f:
                    return FaqConfig.model_validate_json(f.read())
        raise FileNotFoundError(
            "Could not find \"faq_config.json\" in any of the default paths.")

    def save_to_file(self, path: str) -> None:
        """
        Saves a faq_config.json file to the specified path.
        """
        # Written as UTF-8 so that non-ASCII text (e.g. generated typos) can
        # always be encoded and is read back consistently by load_from_file.
        with open(path, "w", encoding="utf-8") as f:
            f.write(self.model_dump_json())

    def iterate_answers(self):
        """Yields the answer of every entry, in entry order."""
        for faq in self.faqs:
            yield faq.answer

    def iterate_questions(self):
        """Yields every matched question of every entry, in entry order."""
        for faq in self.faqs:
            yield from faq.matched_questions

    def question_count(self) -> int:
        """Returns the total number of matched questions over all entries."""
        return sum(len(faq.matched_questions) for faq in self.faqs)

    def filter_short_questions(self, min_words: int) -> None:
        """
        Filters out questions shorter than min_words and removes empty entries.

        Words are counted by splitting on whitespace. The config is modified
        in place.
        """
        for faq in self.faqs:
            faq.matched_questions = [
                q for q in faq.matched_questions if len(q.split()) >= min_words]
        self.faqs = [faq for faq in self.faqs if faq.matched_questions]

    def make_typos(
        self,
        entry_variants: int,
        min_typos: int,
        max_typos: int,
        scale_max_per_word: bool = True,
        scale_min_per_word: bool = False,
        per_word_multiplier: float = 1.0,
        seed: RandomSeed = None
    ) -> tuple[int, int]:
        """
        Makes typo variants of each question of each entry and appends them to
        that entry's matched questions.

        For every existing question, ``entry_variants`` new variants are made.
        Each variant receives a random number of typos between the (possibly
        scaled) minimum and maximum. When a scale flag is set, the matching
        bound is multiplied by ``max(1, word_count * per_word_multiplier)``.

        Returns the number of question variants added and the total number of
        typos made. Raises a ValueError if the arguments are out of range.
        """
        if entry_variants < 1:
            raise ValueError(
                "entry_variants must be greater than or equal to 1")
        if min_typos < 0:
            raise ValueError("min_typos must be greater than or equal to 0")
        if max_typos < 1:
            raise ValueError("max_typos must be greater than or equal to 1")
        if min_typos > max_typos:
            raise ValueError(
                "min_typos must be less than or equal to max_typos")

        seeded_random = Random(seed)
        typo_entry_count = 0
        typo_count = 0
        for faq in self.faqs:
            new_qs: list[str] = []

            for question in faq.matched_questions:
                # Computed unconditionally so that either scaling flag can be
                # enabled on its own.
                num_words = max(
                    1, len(question.split()) * per_word_multiplier)
                q_min_typos = min_typos * num_words if scale_min_per_word else min_typos
                q_max_typos = max_typos * num_words if scale_max_per_word else max_typos

                low = math.ceil(q_min_typos)
                # Scaling only the minimum can push it past the maximum; keep
                # the range valid for randint.
                high = max(low, math.ceil(q_max_typos))

                for _ in range(entry_variants):
                    num_typos = seeded_random.randint(low, high)
                    typo_q = StrErrer(question, seed=seeded_random.random())
                    for _ in range(num_typos):
                        typo_q = random_typo(typo_q, seeded_random)
                    new_qs.append(typo_q.result)
                    typo_count += num_typos

            # Extended after the loop so the new variants are not themselves
            # iterated and typo'd again.
            faq.matched_questions.extend(new_qs)
            typo_entry_count += len(new_qs)

        return typo_entry_count, typo_count

    def make_question_pairs(self) -> Dataset:
        """
        Makes question-to-question pairs from the FAQs, where each question is paired with all
        other questions in its set (positive samples) and from other sets (negative sample).
        """
        return make_entry_pairs([faq.matched_questions for faq in self.faqs])

    def make_question_answer_pairs(self) -> Dataset:
        """
        Makes question-answer pairs from the FAQs, where each question is paired with its correct
        answer (positive sample, score 1.0) and with every answer that differs from it
        (negative samples, score 0.0).
        """
        questions, answers, scores = [], [], []

        # Collected once; the same list serves as the negative pool for every question.
        all_answers = list(self.iterate_answers())

        for faq in self.faqs:
            for question in faq.matched_questions:
                # Positive sample (correct answer)
                questions.append(question)
                answers.append(faq.answer)
                scores.append(1.0)

                # Negative samples (incorrect answers)
                for other_answer in all_answers:
                    if other_answer != faq.answer:
                        questions.append(question)
                        answers.append(other_answer)
                        scores.append(0.0)

        return Dataset.from_dict({
            "sentence1": questions,
            "sentence2": answers,
            "score": scores,
        })

    def make_everything_pairs(self) -> Dataset:
        """
        Makes item-to-item pairs where each entry's title, answer, and questions form one set,
        using make_entry_pairs. Entries without a title contribute only their answer and
        questions.
        """
        entry_sets = []
        for faq in self.faqs:
            # title is optional; a None must not be paired as if it were text.
            title_part = [faq.title] if faq.title is not None else []
            entry_sets.append(
                [*title_part, faq.answer, *faq.matched_questions])
        return make_entry_pairs(entry_sets)
216+
217+
def make_wiki_qa_dataset(faqs: FaqConfig, max_count: int = -1) -> Dataset:
    """
    Makes negative (score 0.0) sentence pairs by crossing the "train" split of
    the "microsoft/wiki_qa" dataset with the given FAQs.

    Each newly seen wiki question is paired with every FAQ answer, and each
    row's wiki answer is paired with every FAQ question. When ``max_count`` is
    greater than zero, generation stops as soon as that many pairs exist;
    otherwise no limit is applied.
    """
    limited = max_count > 0
    wiki_qa = load_dataset("microsoft/wiki_qa")

    def negative_pairs():
        """Yields (sentence1, sentence2) tuples in the order they are emitted."""
        previous_id = ""
        for row in wiki_qa["train"]:
            # A question id equal to the previous row's id means the question
            # was already paired with the FAQ answers.
            current_id = row["question_id"]
            if current_id != previous_id:
                previous_id = current_id
                wiki_question = row["question"]
                for faq_answer in faqs.iterate_answers():
                    yield wiki_question, faq_answer

            # NOTE(review): the answer pairing is applied to every row, not
            # only to the first row of a question - confirm against intent.
            wiki_answer = row["answer"]
            for faq_question in faqs.iterate_questions():
                yield faq_question, wiki_answer

    first_sentences, second_sentences = [], []
    for first, second in negative_pairs():
        first_sentences.append(first)
        second_sentences.append(second)
        # The limit is checked after every single pair.
        if limited and len(first_sentences) >= max_count:
            break

    return Dataset.from_dict({
        "sentence1": first_sentences,
        "sentence2": second_sentences,
        "score": [0.0] * len(first_sentences),
    })
109-
110-
111- def generate_everything_pairs (faqs : list [FaqEntry ]) -> Dataset :
112- """
113- Generates pairs of titles, answers, and questions from the FAQs, where each set is paired with its correct
114- answer (positive sample) and other incorrect answers (negative samples).
115- """
116- return generate_entry_pairs ([[faq .title , faq .answer , * faq .matched_questions ] for faq in faqs ])
117-
118-
119- def split_dataset (dataset : Dataset , eval_percent : float | int ) -> tuple [Dataset , Dataset | None ]:
120- """Splits the dataset into training and evaluation sets based on the evaluation percentage."""
121- if eval_percent > 0 :
122- split = dataset .train_test_split (test_size = eval_percent )
123- return split ["train" ], split ["test" ]
124- return dataset , None
0 commit comments