@@ -2,14 +2,15 @@
 
 import os
 import json
-from dotenv import load_dotenv
+from dataclasses import dataclass
+from typing import List
 import argparse
 import asyncio
-from dataclasses import dataclass
+from tqdm.asyncio import tqdm as tqdm_async
+from dotenv import load_dotenv
+
 from models import OpenAIModel
-from typing import List
 from utils import create_event_loop, compute_content_hash
-from tqdm.asyncio import tqdm as tqdm_async
 
 INSTRUCTION_GENERATION_PROMPT = '''The background knowledge is:
 {doc}
@@ -49,6 +50,7 @@ def _post_process_answers(content: str) -> tuple:
         question = content.split('Question:')[1].split('Answer:')[0].strip()
         answer = content.split('Answer:')[1].strip()
         return question, answer
+    return None, None
 
 @dataclass
 class SelfQA:
@@ -59,7 +61,7 @@ def generate(self, docs: List[List[dict]]) -> List[dict]:
         loop = create_event_loop()
         return loop.run_until_complete(self.async_generate(docs))
 
-    async def async_generate(self, docs: List[List[dict]]) -> List[dict]:
+    async def async_generate(self, docs: List[List[dict]]) -> dict:
         final_results = {}
         semaphore = asyncio.Semaphore(self.max_concurrent)
 
@@ -71,20 +73,26 @@ async def process_chunk(content: str):
                 instruction_questions = _post_process_instructions(response)
 
                 qas = []
-                for qa in tqdm_async(asyncio.as_completed([self.llm_client.generate_answer(READING_COMPREHENSION_PROMPT.format(doc=content, question=question)) for question in instruction_questions]), total=len(instruction_questions), desc="Generating QAs"):
+                for qa in tqdm_async(asyncio.as_completed([
+                    self.llm_client.generate_answer(READING_COMPREHENSION_PROMPT.format(
+                        doc=content,
+                        question=question
+                    )) for question in instruction_questions]),
+                    total=len(instruction_questions), desc="Generating QAs"):
                     try:
                         question, answer = _post_process_answers(await qa)
-                        qas.append({
-                            compute_content_hash(question): {
-                                'question': question,
-                                'answer': answer
-                            }
-                        })
-                    except Exception as e:
+                        if question and answer:
+                            qas.append({
+                                compute_content_hash(question): {
+                                    'question': question,
+                                    'answer': answer
+                                }
+                            })
+                    except Exception as e:  # pylint: disable=broad-except
                         print(f"Error: {e}")
                         continue
                 return qas
-            except Exception as e:
+            except Exception as e:  # pylint: disable=broad-except
                 print(f"Error: {e}")
                 return []
 
@@ -98,7 +106,7 @@ async def process_chunk(content: str):
                 qas = await result
                 for qa in qas:
                     final_results.update(qa)
-            except Exception as e:
+            except Exception as e:  # pylint: disable=broad-except
                 print(f"Error: {e}")
         return final_results
 
@@ -131,15 +139,15 @@ async def process_chunk(content: str):
     self_qa = SelfQA(llm_client=llm_client)
 
     if args.data_type == 'raw':
-        with open(args.input_file, "r") as f:
+        with open(args.input_file, "r", encoding='utf-8') as f:
             data = [json.loads(line) for line in f]
         data = [[chunk] for chunk in data]
     elif args.data_type == 'chunked':
-        with open(args.input_file, "r") as f:
+        with open(args.input_file, "r", encoding='utf-8') as f:
             data = json.load(f)
 
     results = self_qa.generate(data)
 
     # Save results
-    with open(args.output_file, "w") as f:
+    with open(args.output_file, "w", encoding='utf-8') as f:
         json.dump(results, f, indent=4, ensure_ascii=False)
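
The behavioral core of this patch is a parse-then-guard pattern: `_post_process_answers` now falls back to `(None, None)` on malformed model output instead of raising `IndexError` from the `split` calls, and the caller appends a QA pair only when both fields are present. A minimal standalone sketch of how the patched helper behaves; the marker-presence check is an assumption, since the guard condition itself sits just above the lines shown in the hunk:

    def _post_process_answers(content: str) -> tuple:
        # Assumed guard: parse only when both markers are present; otherwise
        # report failure as (None, None) rather than raising on the split.
        if 'Question:' in content and 'Answer:' in content:
            question = content.split('Question:')[1].split('Answer:')[0].strip()
            answer = content.split('Answer:')[1].strip()
            return question, answer
        return None, None

    print(_post_process_answers('Question: Why?\nAnswer: Because.'))  # ('Why?', 'Because.')
    print(_post_process_answers('no markers at all'))                 # (None, None)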
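
The reflowed loop also shows the fan-out pattern `async_generate` relies on: schedule one coroutine per question, bound concurrency with `asyncio.Semaphore`, and drain completions through `tqdm_async(asyncio.as_completed(...))`. A self-contained sketch under stated assumptions; `fake_llm` is a hypothetical stand-in for `self.llm_client.generate_answer`, and the limit of 4 stands in for `self.max_concurrent`:

    import asyncio
    from tqdm.asyncio import tqdm as tqdm_async

    async def main():
        semaphore = asyncio.Semaphore(4)  # stand-in for self.max_concurrent

        async def fake_llm(prompt: str) -> str:
            # Hypothetical stand-in for self.llm_client.generate_answer
            async with semaphore:          # at most 4 "requests" in flight
                await asyncio.sleep(0.1)   # simulated network latency
                return f"Question: {prompt}\nAnswer: stub"

        prompts = [f"q{i}" for i in range(10)]
        answers = []
        # as_completed yields awaitables as each task finishes, so the
        # progress bar advances in completion order, not submission order.
        for task in tqdm_async(asyncio.as_completed([fake_llm(p) for p in prompts]),
                               total=len(prompts), desc="Generating QAs"):
            answers.append(await task)
        print(len(answers))  # 10

    asyncio.run(main())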