Skip to content

Commit 1636a28

Browse files
fix(graphgen): accelerate quizzing
1 parent af23e20 commit 1636a28

File tree

5 files changed

+30
-14
lines changed

5 files changed

+30
-14
lines changed

generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575

7676
graph_gen.quiz(max_samples=3)
7777

78-
graph_gen.judge(re_judge=True)
78+
graph_gen.judge(re_judge=False)
7979

8080
graph_gen.traverse()
8181
with open(os.path.join(sys_path, "cache", "configs", f"graphgen_{unique_id}.yaml"), "w", encoding='utf-8') as f:

graphgen/operators/quiz_relations.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,37 @@ async def _quiz_single_relation(
4242
if not descriptions:
4343
# 多次采样,取平均
4444
descriptions = [(description, 'yes')]
45+
46+
new_description_tasks = []
47+
new_anti_description_tasks = []
4548
for i in range(max_samples):
4649
if i > 0:
47-
new_description = await teacher_llm_client.generate_answer(
48-
DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(input_sentence=description),
50+
new_description_tasks.append(
51+
teacher_llm_client.generate_answer(
52+
DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(
53+
input_sentence=description),
54+
temperature=1
55+
)
56+
)
57+
new_anti_description_tasks.append(
58+
teacher_llm_client.generate_answer(
59+
DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(
60+
input_sentence=description),
4961
temperature=1
5062
)
51-
descriptions.append((new_description, 'yes'))
52-
new_anti_description = await teacher_llm_client.generate_answer(
53-
DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(input_sentence=description),
54-
temperature=1
5563
)
64+
65+
new_descriptions = await asyncio.gather(*new_description_tasks)
66+
new_anti_descriptions = await asyncio.gather(*new_anti_description_tasks)
67+
68+
for new_description in new_descriptions:
69+
descriptions.append((new_description, 'yes'))
70+
for new_anti_description in new_anti_descriptions:
5671
descriptions.append((new_anti_description, 'no'))
5772

5873
descriptions = list(set(descriptions))
5974
except Exception as e: # pylint: disable=broad-except
60-
logger.error(f"Error when quizzing edge {source_id} -> {target_id}: {e}")
75+
logger.error("Error when quizzing edge %s -> %s: %s", source_id, target_id, e)
6176
descriptions = [(description, 'yes')]
6277

6378
await rephrase_storage.upsert({description: descriptions})

models/llm/tokenizer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
import tiktoken
21
from dataclasses import dataclass
32
from typing import List
3+
import tiktoken
44

55
try:
66
from transformers import AutoTokenizer
77
TRANSFORMERS_AVAILABLE = True
88
except ImportError:
9+
AutoTokenizer = None
910
TRANSFORMERS_AVAILABLE = False
1011

1112

@@ -18,11 +19,11 @@ def get_tokenizer(tokenizer_name: str = "cl100k_base"):
1819
"""
1920
if tokenizer_name in tiktoken.list_encoding_names():
2021
return tiktoken.get_encoding(tokenizer_name)
21-
elif TRANSFORMERS_AVAILABLE:
22+
if TRANSFORMERS_AVAILABLE:
2223
try:
2324
return AutoTokenizer.from_pretrained(tokenizer_name)
2425
except Exception as e:
25-
raise ValueError(f"Failed to load tokenizer from Hugging Face: {e}")
26+
raise ValueError(f"Failed to load tokenizer from Hugging Face: {e}") from e
2627
else:
2728
raise ValueError("Hugging Face Transformers is not available, please install it first.")
2829

models/llm/topk_token_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ class TopkTokenModel:
2020
do_sample: bool = False
2121
temperature: float = 0
2222
max_tokens: int = 10240
23-
repetition_penalty: float = 1.0
23+
repetition_penalty: float = 1.05
2424
num_beams: int = 1
2525
topk: int = 50
26-
topp: float = 0.1
26+
topp: float = 0.95
2727

2828
topk_per_token: int = 5 # number of topk tokens to generate for each token
2929

models/strategy/travserse_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class TraverseStrategy(BaseStrategy):
1818
# 同一层中选边的策略(如果是双向拓展,同一层指的是两边连接的边的集合)
1919
edge_sampling: str = "max_loss" # "max_loss" or "min_loss" or "random"
2020
# 孤立节点的处理策略
21-
isolated_node_strategy: str = "add" # "add" or "ignore"
21+
isolated_node_strategy: str = "ignore" # "add" or "ignore"
2222
# 难度顺序 ["easy", "medium", "hard"], ["hard", "medium", "easy"], ["medium", "medium", "medium"]
2323
difficulty_order: list = field(default_factory=lambda: ["easy", "medium", "hard"])
2424

0 commit comments

Comments
 (0)