Skip to content

Commit ccaf863

Browse files
committed
refactor: use run_concurrent in quiz operator and fix language codes
- Replace manual asyncio with run_concurrent utility
- Fix language codes: English/Chinese -> en/zh
- Add progress_bar support
1 parent d10e301 commit ccaf863

File tree

5 files changed

+43
-45
lines changed

5 files changed

+43
-45
lines changed

graphgen/graphgen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
221221
self.graph_storage,
222222
self.rephrase_storage,
223223
max_samples,
224+
progress_bar=self.progress_bar,
224225
)
225226

226227
# TODO: assert trainee_llm_client is valid before judge

graphgen/models/generator/quiz_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def build_prompt_for_description(description: str, template_type: str = "TEMPLAT
4343
:param template_type: Either "TEMPLATE" (same meaning) or "ANTI_TEMPLATE" (opposite meaning)
4444
:return: Prompt string
4545
"""
46-
language = "English" if detect_main_language(description) == "en" else "Chinese"
46+
language = "en" if detect_main_language(description) == "en" else "zh"
4747
prompt = DESCRIPTION_REPHRASING_PROMPT[language][template_type].format(
4848
input_sentence=description
4949
)

graphgen/operators/generate/generate_qas.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Any
22

3+
import gradio as gr
4+
35
from graphgen.bases import BaseLLMWrapper
46
from graphgen.models import (
57
AggregatedGenerator,
@@ -19,7 +21,7 @@ async def generate_qas(
1921
]
2022
],
2123
generation_config: dict,
22-
progress_bar=None,
24+
progress_bar: gr.Progress = None,
2325
) -> list[dict[str, Any]]:
2426
"""
2527
Generate question-answer pairs based on nodes and edges.

graphgen/operators/quiz_and_judge/quiz.py

Lines changed: 36 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
1-
import asyncio
21
from collections import defaultdict
32

4-
from tqdm.asyncio import tqdm as tqdm_async
3+
import gradio as gr
54

65
from graphgen.bases import BaseLLMWrapper
76
from graphgen.models import JsonKVStorage, NetworkXStorage, QuizGenerator
8-
from graphgen.utils import logger
7+
from graphgen.utils import logger, run_concurrent
98

109

1110
async def quiz(
1211
synth_llm_client: BaseLLMWrapper,
1312
graph_storage: NetworkXStorage,
1413
rephrase_storage: JsonKVStorage,
1514
max_samples: int = 1,
16-
max_concurrent: int = 1000,
15+
progress_bar: gr.Progress = None,
1716
) -> JsonKVStorage:
1817
"""
1918
Get all edges and quiz them using QuizGenerator.
@@ -22,37 +21,36 @@ async def quiz(
2221
:param graph_storage: graph storage instance
2322
:param rephrase_storage: rephrase storage instance
2423
:param max_samples: max samples for each edge
25-
:param max_concurrent: max concurrent
24+
:param progress_bar
2625
:return:
2726
"""
2827

29-
semaphore = asyncio.Semaphore(max_concurrent)
3028
generator = QuizGenerator(synth_llm_client)
3129

32-
async def _process_single_quiz(description: str, template_type: str, gt: str):
33-
async with semaphore:
34-
try:
35-
# if rephrase_storage exists already, directly get it
36-
descriptions = await rephrase_storage.get_by_id(description)
37-
if descriptions:
38-
return None
39-
40-
prompt = generator.build_prompt_for_description(description, template_type)
41-
new_description = await synth_llm_client.generate_answer(
42-
prompt, temperature=1
43-
)
44-
rephrased_text = generator.parse_rephrased_text(new_description)
45-
return {description: [(rephrased_text, gt)]}
46-
47-
except Exception as e: # pylint: disable=broad-except
48-
logger.error("Error when quizzing description %s: %s", description, e)
30+
async def _process_single_quiz(item: tuple[str, str, str]):
31+
description, template_type, gt = item
32+
try:
33+
# if rephrase_storage exists already, directly get it
34+
descriptions = await rephrase_storage.get_by_id(description)
35+
if descriptions:
4936
return None
5037

38+
prompt = generator.build_prompt_for_description(description, template_type)
39+
new_description = await synth_llm_client.generate_answer(
40+
prompt, temperature=1
41+
)
42+
rephrased_text = generator.parse_rephrased_text(new_description)
43+
return {description: [(rephrased_text, gt)]}
44+
45+
except Exception as e: # pylint: disable=broad-except
46+
logger.error("Error when quizzing description %s: %s", description, e)
47+
return None
48+
5149
edges = await graph_storage.get_all_edges()
5250
nodes = await graph_storage.get_all_nodes()
5351

5452
results = defaultdict(list)
55-
tasks = []
53+
items = []
5654
for edge in edges:
5755
edge_data = edge[2]
5856
description = edge_data["description"]
@@ -61,12 +59,8 @@ async def _process_single_quiz(description: str, template_type: str, gt: str):
6159

6260
for i in range(max_samples):
6361
if i > 0:
64-
tasks.append(
65-
_process_single_quiz(description, "TEMPLATE", "yes")
66-
)
67-
tasks.append(
68-
_process_single_quiz(description, "ANTI_TEMPLATE", "no")
69-
)
62+
items.append((description, "TEMPLATE", "yes"))
63+
items.append((description, "ANTI_TEMPLATE", "no"))
7064

7165
for node in nodes:
7266
node_data = node[1]
@@ -76,17 +70,18 @@ async def _process_single_quiz(description: str, template_type: str, gt: str):
7670

7771
for i in range(max_samples):
7872
if i > 0:
79-
tasks.append(
80-
_process_single_quiz(description, "TEMPLATE", "yes")
81-
)
82-
tasks.append(
83-
_process_single_quiz(description, "ANTI_TEMPLATE", "no")
84-
)
85-
86-
for result in tqdm_async(
87-
asyncio.as_completed(tasks), total=len(tasks), desc="Quizzing descriptions"
88-
):
89-
new_result = await result
73+
items.append((description, "TEMPLATE", "yes"))
74+
items.append((description, "ANTI_TEMPLATE", "no"))
75+
76+
quiz_results = await run_concurrent(
77+
_process_single_quiz,
78+
items,
79+
desc="Quizzing descriptions",
80+
unit="description",
81+
progress_bar=progress_bar,
82+
)
83+
84+
for new_result in quiz_results:
9085
if new_result:
9186
for key, value in new_result.items():
9287
results[key].extend(value)

graphgen/templates/description_rephrasing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,11 @@
110110

111111

112112
DESCRIPTION_REPHRASING_PROMPT= {
113-
"English": {
113+
"en": {
114114
"ANTI_TEMPLATE": ANTI_TEMPLATE_EN,
115115
"TEMPLATE": TEMPLATE_EN
116116
},
117-
"Chinese": {
117+
"zh": {
118118
"ANTI_TEMPLATE": ANTI_TEMPLATE_ZH,
119119
"TEMPLATE": TEMPLATE_ZH
120120
}

0 commit comments

Comments
 (0)