1- import asyncio
21from collections import defaultdict
32
4- from tqdm . asyncio import tqdm as tqdm_async
3+ import gradio as gr
54
65from graphgen .bases import BaseLLMWrapper
76from graphgen .models import JsonKVStorage , NetworkXStorage , QuizGenerator
8- from graphgen .utils import logger
7+ from graphgen .utils import logger , run_concurrent
98
109
1110async def quiz (
1211 synth_llm_client : BaseLLMWrapper ,
1312 graph_storage : NetworkXStorage ,
1413 rephrase_storage : JsonKVStorage ,
1514 max_samples : int = 1 ,
16- max_concurrent : int = 1000 ,
15+ progress_bar : gr . Progress = None ,
1716) -> JsonKVStorage :
1817 """
1918 Get all edges and quiz them using QuizGenerator.
@@ -22,37 +21,36 @@ async def quiz(
2221 :param graph_storage: graph storage instance
2322 :param rephrase_storage: rephrase storage instance
2423 :param max_samples: max samples for each edge
25- :param max_concurrent: max concurrent
24+ :param progress_bar: optional gradio progress bar, forwarded to run_concurrent
2625 :return:
2726 """
2827
29- semaphore = asyncio .Semaphore (max_concurrent )
3028 generator = QuizGenerator (synth_llm_client )
3129
async def _process_single_quiz(item: tuple[str, str, str]):
    """
    Rephrase a single description with the synth LLM.

    :param item: (description, template_type, gt) triple
    :return: {description: [(rephrased_text, gt)]}, or None when
        rephrase_storage already returns something for the description
        or an error was logged
    """
    description, template_type, gt = item
    try:
        # nothing to do when rephrase_storage already holds this description
        if await rephrase_storage.get_by_id(description):
            return None

        prompt = generator.build_prompt_for_description(description, template_type)
        raw_answer = await synth_llm_client.generate_answer(prompt, temperature=1)
        return {description: [(generator.parse_rephrased_text(raw_answer), gt)]}
    except Exception as e:  # pylint: disable=broad-except
        # best effort: one failing description must not abort the whole batch
        logger.error("Error when quizzing description %s: %s", description, e)
    return None
48+
5149 edges = await graph_storage .get_all_edges ()
5250 nodes = await graph_storage .get_all_nodes ()
5351
5452 results = defaultdict (list )
55- tasks = []
53+ items = []
5654 for edge in edges :
5755 edge_data = edge [2 ]
5856 description = edge_data ["description" ]
@@ -61,12 +59,8 @@ async def _process_single_quiz(description: str, template_type: str, gt: str):
6159
6260 for i in range (max_samples ):
6361 if i > 0 :
64- tasks .append (
65- _process_single_quiz (description , "TEMPLATE" , "yes" )
66- )
67- tasks .append (
68- _process_single_quiz (description , "ANTI_TEMPLATE" , "no" )
69- )
62+ items .append ((description , "TEMPLATE" , "yes" ))
63+ items .append ((description , "ANTI_TEMPLATE" , "no" ))
7064
7165 for node in nodes :
7266 node_data = node [1 ]
@@ -76,17 +70,18 @@ async def _process_single_quiz(description: str, template_type: str, gt: str):
7670
7771 for i in range (max_samples ):
7872 if i > 0 :
79- tasks .append (
80- _process_single_quiz (description , "TEMPLATE" , "yes" )
81- )
82- tasks .append (
83- _process_single_quiz (description , "ANTI_TEMPLATE" , "no" )
84- )
85-
86- for result in tqdm_async (
87- asyncio .as_completed (tasks ), total = len (tasks ), desc = "Quizzing descriptions"
88- ):
89- new_result = await result
73+ items .append ((description , "TEMPLATE" , "yes" ))
74+ items .append ((description , "ANTI_TEMPLATE" , "no" ))
75+
76+ quiz_results = await run_concurrent (
77+ _process_single_quiz ,
78+ items ,
79+ desc = "Quizzing descriptions" ,
80+ unit = "description" ,
81+ progress_bar = progress_bar ,
82+ )
83+
84+ for new_result in quiz_results :
9085 if new_result :
9186 for key , value in new_result .items ():
9287 results [key ].extend (value )
0 commit comments