|
1 | 1 | import asyncio |
| 2 | +from collections import defaultdict |
2 | 3 |
|
3 | 4 | from tqdm.asyncio import tqdm as tqdm_async |
4 | 5 | from models import JsonKVStorage, OpenAIModel, NetworkXStorage |
@@ -27,67 +28,67 @@ async def quiz_relations( |
27 | 28 |
|
28 | 29 | async def _quiz_single_relation( |
29 | 30 | edge: tuple, |
| 31 | + des: str, |
| 32 | + prompt: str, |
| 33 | + gt: str |
30 | 34 | ): |
31 | 35 | async with semaphore: |
32 | 36 | source_id = edge[0] |
33 | 37 | target_id = edge[1] |
34 | | - edge_data = edge[2] |
35 | | - |
36 | | - description = edge_data["description"] |
37 | | - language = "English" if detect_main_language(description) == "en" else "Chinese" |
38 | 38 |
|
39 | 39 | try: |
40 | 40 | # 如果在rephrase_storage中已经存在,直接取出 |
41 | | - descriptions = await rephrase_storage.get_by_id(description) |
42 | | - if not descriptions: |
43 | | - # 多次采样,取平均 |
44 | | - descriptions = [(description, 'yes')] |
45 | | - |
46 | | - new_description_tasks = [] |
47 | | - new_anti_description_tasks = [] |
48 | | - for i in range(max_samples): |
49 | | - if i > 0: |
50 | | - new_description_tasks.append( |
51 | | - teacher_llm_client.generate_answer( |
52 | | - DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format( |
53 | | - input_sentence=description), |
54 | | - temperature=1 |
55 | | - ) |
56 | | - ) |
57 | | - new_anti_description_tasks.append( |
58 | | - teacher_llm_client.generate_answer( |
59 | | - DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format( |
60 | | - input_sentence=description), |
61 | | - temperature=1 |
62 | | - ) |
63 | | - ) |
64 | | - |
65 | | - new_descriptions = await asyncio.gather(*new_description_tasks) |
66 | | - new_anti_descriptions = await asyncio.gather(*new_anti_description_tasks) |
67 | | - |
68 | | - for new_description in new_descriptions: |
69 | | - descriptions.append((new_description, 'yes')) |
70 | | - for new_anti_description in new_anti_descriptions: |
71 | | - descriptions.append((new_anti_description, 'no')) |
72 | | - |
73 | | - descriptions = list(set(descriptions)) |
| 41 | + descriptions = await rephrase_storage.get_by_id(des) |
| 42 | + if descriptions: |
| 43 | + return None |
| 44 | + |
| 45 | + new_description = await teacher_llm_client.generate_answer( |
| 46 | + prompt, |
| 47 | + temperature=1 |
| 48 | + ) |
| 49 | + return {des: [(new_description, gt)]} |
| 50 | + |
74 | 51 | except Exception as e: # pylint: disable=broad-except |
75 | 52 | logger.error("Error when quizzing edge %s -> %s: %s", source_id, target_id, e) |
76 | | - descriptions = [(description, 'yes')] |
| 53 | + return None |
| 54 | + |
77 | 55 |
|
78 | | - await rephrase_storage.upsert({description: descriptions}) |
| 56 | + edges = await graph_storage.get_all_edges() |
79 | 57 |
|
80 | | - return {description: descriptions} |
| 58 | + results = defaultdict(list) |
| 59 | + tasks = [] |
| 60 | + for edge in edges: |
| 61 | + edge_data = edge[2] |
81 | 62 |
|
| 63 | + description = edge_data["description"] |
| 64 | + language = "English" if detect_main_language(description) == "en" else "Chinese" |
| 65 | + |
| 66 | + results[description] = [(description, 'yes')] |
| 67 | + |
| 68 | + for i in range(max_samples): |
| 69 | + if i > 0: |
| 70 | + tasks.append( |
| 71 | + _quiz_single_relation(edge, description, |
| 72 | + DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format( |
| 73 | + input_sentence=description), 'yes') |
| 74 | + ) |
| 75 | + tasks.append(_quiz_single_relation(edge, description, |
| 76 | + DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format( |
| 77 | + input_sentence=description), 'no')) |
82 | 78 |
|
83 | | - edges = await graph_storage.get_all_edges() |
84 | 79 |
|
85 | | - results = [] |
86 | 80 | for result in tqdm_async( |
87 | | - asyncio.as_completed([_quiz_single_relation(edge) for edge in edges]), |
88 | | - total=len(edges), |
| 81 | + asyncio.as_completed(tasks), |
| 82 | + total=len(tasks), |
89 | 83 | desc="Quizzing relations" |
90 | 84 | ): |
91 | | - results.append(await result) |
| 85 | + new_result = await result |
| 86 | + if new_result: |
| 87 | + for key, value in new_result.items(): |
| 88 | + results[key].extend(value) |
| 89 | + |
| 90 | + for key, value in results.items(): |
| 91 | + results[key] = list(set(value)) |
| 92 | + await rephrase_storage.upsert({key: results[key]}) |
92 | 93 |
|
93 | 94 | return rephrase_storage |
0 commit comments