@@ -35,9 +35,9 @@ def setup_dataset(self):
3535 (
3636 self ._task_names ,
3737 self ._task_to_id ,
38- train_dataset ,
39- _ ,
40- _ ,
38+ self . train_dataset ,
39+ self . dev_dataset ,
40+ self . test_dataset ,
4141 ) = maybe_filter_hf_dataset_by_task (
4242 dataset , self .config .task_name_field , self .config .finetune_task_name
4343 )
@@ -124,243 +124,22 @@ def expand_questions(examples, tokenizer):
124124 if self .tokenizer .chat_template is None :
125125 self .tokenizer .apply_chat_template = lambda x , ** kwargs : x [0 ]["content" ]
126126
127- if "split" in train_dataset .features :
128- self .train_dataset , self .dev_dataset , self .test_dataset = (
129- split_on_split_column (train_dataset )
130- )
131- self .train_dataset = self .train_dataset .map (
132- lambda examples : expand_questions (examples , self .tokenizer ),
133- batched = True ,
134- batch_size = 1000 ,
135- num_proc = 1 ,
136- remove_columns = train_dataset .column_names ,
137- )
127+ self .train_dataset = self .train_dataset .map (
128+ lambda examples : expand_questions (examples , self .tokenizer ),
129+ batched = True ,
130+ batch_size = 1000 ,
131+ num_proc = 1 ,
132+ remove_columns = self .train_dataset .column_names ,
133+ )
134+ if self .dev_dataset :
138135 self .dev_dataset = self .dev_dataset .map (
139136 lambda examples : expand_questions (examples , self .tokenizer ),
140137 batched = True ,
141138 batch_size = 1000 ,
142139 num_proc = 1 ,
143- remove_columns = train_dataset .column_names ,
140+ remove_columns = self . dev_dataset .column_names ,
144141 )
145- self .test_dataset = self .dev_dataset
146142 else :
147- train_dataset = train_dataset .map (
148- lambda examples : expand_questions (examples , self .tokenizer ),
149- batched = True ,
150- batch_size = 1000 ,
151- num_proc = 1 ,
152- remove_columns = train_dataset .column_names ,
153- )
154- self .train_dataset = self .dev_dataset = self .test_dataset = train_dataset
155-
156-
157- prompt_template_w_docs = """
158- --------------BEGIN CONTEXT--------------
159-
160- {documents}
161-
162- --------------END CONTEXT--------------
163-
164- {question_text}
165- {options}
166-
167- Please answer using the following format:
168- 0. Begin your answer with the phrase "The correct answer is".
169- 1. State the letter of the correct option (e.g., A, B, C, D).
170- 2. Follow the letter with a colon and the exact text of the option you chose.
171- 3. Make sure your answer is a single, concise sentence.
172-
173- For example, if the correct answer to a question is option C, and the text for C is 'Acute Bronchitis', your answer should be:
174- 'The correct answer is C: Acute bronchitis.'
175- """
176-
177- prompt_template_no_docs = """
178- {question_text}
179- {options}
180-
181- Please answer using the following format:
182- 1. Begin your answer with the phrase "The correct answer is".
183- 2. State the letter of the correct option (e.g., A, B, C, D).
184- 3. Follow the letter with a colon and the exact text of the option you chose.
185- 4. Make sure your answer is a single, concise sentence.
186-
187- For example, if the correct answer to a question is option C, and the text for C is 'Acute Bronchitis', your answer should be:
188- 'The correct answer is C: Acute bronchitis.'
189- """
190-
191- max_new_tokens = 50
192-
193-
194- @dataclass
195- class GenQualityDatasetConfig (DatasetConfig ):
196- task_name_field : str = "document_id"
197- task_source_field : str = "document_id"
198- prompt : str = (
199- "Answer the following question. Give only the answer, and no extra commentary, formatting, or chattiness. Question: "
200- )
201- include_context : bool = False
202- topk_context : int = 10
203- include_all_answers : bool = True
204-
205-
206- @DataModule .register ("gen_quality" , config_cls = GenQualityDatasetConfig )
207- class GenQualityDataModule (DataModule ):
208- def setup_dataset (self ):
209- from mttl .models .library .dataset_library import DatasetLibrary
210-
211- dataset = DatasetLibrary .pull_dataset (self .config .dataset )
212-
213- # Instead of always working with the large datasets, we can subsample it
214- if self .config .custom_split_file :
215- dataset = apply_custom_split_file (dataset , self .config .custom_split_file )
216-
217- (
218- self ._task_names ,
219- self ._task_to_id ,
220- train_dataset ,
221- _ ,
222- _ ,
223- ) = maybe_filter_hf_dataset_by_task (
224- dataset , self .config .task_name_field , self .config .finetune_task_name
225- )
226-
227- # Let's make sure that the full prompt is always in context
228- len_template = len (self .tokenizer .encode (prompt_template_w_docs ))
229-
230- def expand_questions (examples , tokenizer , len_template ):
231- batch = {
232- "source" : [],
233- "target" : [],
234- "document_id" : [],
235- }
143+ self .dev_dataset = self .train_dataset
236144
237- for i in range (len (examples ["document_id" ])):
238- for j in range (len (examples ["questions" ][i ])):
239- document_id = examples ["document_id" ][i ]
240- question = examples ["questions" ][i ][j ]
241- options = examples ["options" ][i ][j ]
242- gold_label = examples ["gold_label" ][i ][j ]
243- if gold_label == - 1 :
244- gold_label = label_index = None
245- else :
246- label_index = gold_label - 1
247-
248- """ NEW """
249- letters = ["A" , "B" , "C" , "D" ]
250- option_str = "\n " .join (
251- [f"{ letters [i ]} : { option } " for i , option in enumerate (options )]
252- )
253- len_question = len (tokenizer .encode (question ))
254- len_options = len (tokenizer .encode (option_str ))
255- len_suffix = len (tokenizer .encode ("The correct answer is: " ))
256-
257- total_len = len_question + len_options + len_template + len_suffix
258-
259- if self .config .include_context :
260- context = examples ["text" ][i ]
261-
262- if isinstance (context , list ):
263- # following Alan's approach
264- context = " " .join (
265- [
266- f"Passage { k + 1 } : { context [k ]} \n \n "
267- for k in range (
268- min (self .config .topk_context , len (context ))
269- )[::- 1 ]
270- ]
271- )
272- assert (
273- type (context ) == str
274- ), f"Context should be a string, but got { type (context )} "
275-
276- # Let's do some rough trucation if needed
277- context_ids = tokenizer .encode (context )
278- len_context = len (context_ids )
279- space_left = self .config .max_input_length - total_len
280-
281- if space_left < len_context :
282- context_ids = context_ids [: max (0 , space_left - 20 )]
283- context = tokenizer .decode (
284- context_ids , skip_special_tokens = True
285- )
286-
287- prompt = prompt_template_w_docs .format (
288- documents = context ,
289- question_text = question ,
290- options = option_str ,
291- )
292- else :
293- prompt = prompt_template_no_docs .format (
294- question_text = question ,
295- options = option_str ,
296- )
297-
298- """
299- source = [
300- {
301- "role": "system",
302- "content": sys_prompt,
303- },
304- {
305- "role": "user",
306- "content": prompt,
307- },
308- ]
309- """
310- source = [
311- {
312- "role" : "user" ,
313- "content" : prompt ,
314- }
315- ]
316-
317- batch ["source" ].append (
318- tokenizer .apply_chat_template (
319- source , add_generation_prompt = True , tokenize = False
320- )
321- + "The correct answer is"
322- )
323- batch ["target" ].append (
324- letters [label_index ]
325- ) # [options[label_index]])
326- batch ["document_id" ].append (examples ["document_id" ][i ])
327-
328- return batch
329-
330- if self .tokenizer .chat_template is None :
331- self .tokenizer .apply_chat_template = lambda x , ** kwargs : x [0 ]["content" ]
332-
333- if "split" in train_dataset .features :
334- self .train_dataset , self .dev_dataset , self .test_dataset = (
335- split_on_split_column (train_dataset )
336- )
337- self .train_dataset = self .train_dataset .map (
338- lambda examples : expand_questions (
339- examples , self .tokenizer , len_template
340- ),
341- batched = True ,
342- batch_size = 1000 ,
343- num_proc = 1 ,
344- remove_columns = train_dataset .column_names ,
345- )
346- self .dev_dataset = self .dev_dataset .map (
347- lambda examples : expand_questions (
348- examples , self .tokenizer , len_template
349- ),
350- batched = True ,
351- batch_size = 1000 ,
352- num_proc = 1 ,
353- remove_columns = train_dataset .column_names ,
354- )
355- self .test_dataset = self .dev_dataset
356- else :
357- train_dataset = train_dataset .map (
358- lambda examples : expand_questions (
359- examples , self .tokenizer , len_template
360- ),
361- batched = True ,
362- batch_size = 1000 ,
363- num_proc = 1 ,
364- remove_columns = train_dataset .column_names ,
365- )
366- self .train_dataset = self .dev_dataset = self .test_dataset = train_dataset
145+ self .test_dataset = self .dev_dataset
0 commit comments