Skip to content

Commit 8980480

Browse files
fix(graphgen): optimize quizzing
1 parent 1636a28 commit 8980480

File tree

1 file changed

+46
-45
lines changed

1 file changed

+46
-45
lines changed
Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from collections import defaultdict
23

34
from tqdm.asyncio import tqdm as tqdm_async
45
from models import JsonKVStorage, OpenAIModel, NetworkXStorage
@@ -27,67 +28,67 @@ async def quiz_relations(
2728

2829
async def _quiz_single_relation(
2930
edge: tuple,
31+
des: str,
32+
prompt: str,
33+
gt: str
3034
):
3135
async with semaphore:
3236
source_id = edge[0]
3337
target_id = edge[1]
34-
edge_data = edge[2]
35-
36-
description = edge_data["description"]
37-
language = "English" if detect_main_language(description) == "en" else "Chinese"
3838

3939
try:
4040
# 如果在rephrase_storage中已经存在,直接取出
41-
descriptions = await rephrase_storage.get_by_id(description)
42-
if not descriptions:
43-
# 多次采样,取平均
44-
descriptions = [(description, 'yes')]
45-
46-
new_description_tasks = []
47-
new_anti_description_tasks = []
48-
for i in range(max_samples):
49-
if i > 0:
50-
new_description_tasks.append(
51-
teacher_llm_client.generate_answer(
52-
DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(
53-
input_sentence=description),
54-
temperature=1
55-
)
56-
)
57-
new_anti_description_tasks.append(
58-
teacher_llm_client.generate_answer(
59-
DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(
60-
input_sentence=description),
61-
temperature=1
62-
)
63-
)
64-
65-
new_descriptions = await asyncio.gather(*new_description_tasks)
66-
new_anti_descriptions = await asyncio.gather(*new_anti_description_tasks)
67-
68-
for new_description in new_descriptions:
69-
descriptions.append((new_description, 'yes'))
70-
for new_anti_description in new_anti_descriptions:
71-
descriptions.append((new_anti_description, 'no'))
72-
73-
descriptions = list(set(descriptions))
41+
descriptions = await rephrase_storage.get_by_id(des)
42+
if descriptions:
43+
return None
44+
45+
new_description = await teacher_llm_client.generate_answer(
46+
prompt,
47+
temperature=1
48+
)
49+
return {des: [(new_description, gt)]}
50+
7451
except Exception as e: # pylint: disable=broad-except
7552
logger.error("Error when quizzing edge %s -> %s: %s", source_id, target_id, e)
76-
descriptions = [(description, 'yes')]
53+
return None
54+
7755

78-
await rephrase_storage.upsert({description: descriptions})
56+
edges = await graph_storage.get_all_edges()
7957

80-
return {description: descriptions}
58+
results = defaultdict(list)
59+
tasks = []
60+
for edge in edges:
61+
edge_data = edge[2]
8162

63+
description = edge_data["description"]
64+
language = "English" if detect_main_language(description) == "en" else "Chinese"
65+
66+
results[description] = [(description, 'yes')]
67+
68+
for i in range(max_samples):
69+
if i > 0:
70+
tasks.append(
71+
_quiz_single_relation(edge, description,
72+
DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(
73+
input_sentence=description), 'yes')
74+
)
75+
tasks.append(_quiz_single_relation(edge, description,
76+
DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(
77+
input_sentence=description), 'no'))
8278

83-
edges = await graph_storage.get_all_edges()
8479

85-
results = []
8680
for result in tqdm_async(
87-
asyncio.as_completed([_quiz_single_relation(edge) for edge in edges]),
88-
total=len(edges),
81+
asyncio.as_completed(tasks),
82+
total=len(tasks),
8983
desc="Quizzing relations"
9084
):
91-
results.append(await result)
85+
new_result = await result
86+
if new_result:
87+
for key, value in new_result.items():
88+
results[key].extend(value)
89+
90+
for key, value in results.items():
91+
results[key] = list(set(value))
92+
await rephrase_storage.upsert({key: results[key]})
9293

9394
return rephrase_storage

0 commit comments

Comments
 (0)