Commit 7ae7955

feat(graphgen): multi-sample when judging
1 parent ca191f3 commit 7ae7955

9 files changed: +143, -55 lines

baselines/Genie/genie.py

Lines changed: 16 additions & 10 deletions
@@ -1,17 +1,20 @@
 # https://arxiv.org/pdf/2401.14367
+
 import os
 import json
 import argparse
 import asyncio
-
+from typing import List
 from dataclasses import dataclass
+from tqdm.asyncio import tqdm as tqdm_async
 from dotenv import load_dotenv
+
 from models import OpenAIModel
-from typing import List
 from utils import create_event_loop, compute_content_hash
-from tqdm.asyncio import tqdm as tqdm_async
 
-PROMPT_TEMPLATE = '''Instruction: Given the next [document], create a [question] and [answer] pair that are grounded in the main point of the document, don't add any additional information that is not in the document. The [question] is by an information-seeking user and the [answer] is provided by a helping AI Agent.
+PROMPT_TEMPLATE = '''Instruction: Given the next [document], create a [question] and [answer] pair that are grounded \
+in the main point of the document, don't add any additional information that is not in the document. The [question] is \
+by an information-seeking user and the [answer] is provided by a helping AI Agent.
 
 [document]: Scrumptious Sweet Co. factory ...
 
@@ -23,13 +26,16 @@
 
 ### Response:
 [question]: What is the plot of the show Schitt's Creek?
-[answer]: The show Schitt's Creek is about a wealthy family who loses their fortune and is forced to rebuild their lives in a small town. The show follows the family as they adjust to their new life in the town and learn to appreciate the simple things in life.
+[answer]: The show Schitt's Creek is about a wealthy family who loses their fortune and is forced to rebuild their \
+lives in a small town. The show follows the family as they adjust to their new life in the town and learn to \
+appreciate the simple things in life.
 
 [document]: 2016's countdown broke several Hottest 100 records ...
 
 ### Response:
 [question]: What was the most popular song on the 2016 Hottest 100?
-[answer]: The most popular song on the 2016 Hottest 100 was "Never Be Like You" by Flume. This was the first time that an electronic dance music producer topped the countdown.
+[answer]: The most popular song on the 2016 Hottest 100 was "Never Be Like You" by Flume. This was the first time that \
+an electronic dance music producer topped the countdown.
 
 [document]: In Greek mythology, Persephone ...
 
@@ -79,7 +85,7 @@ async def process_chunk(content: str):
                 'question': question,
                 'answer': answer
             }
-        except Exception as e:
+        except Exception as e: # pylint: disable=broad-except
            print(f"Error: {e}")
    return final_results
 
@@ -112,15 +118,15 @@ async def process_chunk(content: str):
    genie = Genie(llm_client=llm_client)
 
    if args.data_type == 'raw':
-        with open(args.input_file, "r") as f:
+        with open(args.input_file, "r", encoding='utf-8') as f:
            data = [json.loads(line) for line in f]
        data = [[chunk] for chunk in data]
    elif args.data_type == 'chunked':
-        with open(args.input_file, "r") as f:
+        with open(args.input_file, "r", encoding='utf-8') as f:
            data = json.load(f)
 
    results = genie.generate(data)
 
    # Save results
-    with open(args.output_file, "w") as f:
+    with open(args.output_file, "w", encoding='utf-8') as f:
        json.dump(results, f, indent=4, ensure_ascii=False)

generate.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@
 
     graph_gen.insert(data, args.data_type)
 
-    graph_gen.judge(re_judge=False)
+    graph_gen.judge(re_judge=True, max_samples=3)
 
     graph_gen.traverse()
     with open(os.path.join(sys_path, "cache", "configs", f"graphgen_{unique_id}.yaml"), "w", encoding='utf-8') as f:
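Note on what max_samples=3 implies, based on the sampling loop in the judging file further down: each edge contributes the original description plus max_samples - 1 paraphrases (ground truth "yes") and max_samples negations (ground truth "no"), before de-duplication. A quick count, plain arithmetic for illustration only:

max_samples = 3
n_yes = 1 + (max_samples - 1)   # original description + extra paraphrases
n_no = max_samples              # one negated rephrasing per sample
total = n_yes + n_no            # up to 6 judged statements per edge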

graphgen/graphgen.py

Lines changed: 11 additions & 5 deletions
@@ -32,6 +32,9 @@ class GraphGen:
     graph_storage: NetworkXStorage = NetworkXStorage(
         working_dir, namespace="graph"
     )
+    rephrase_storage: JsonKVStorage = JsonKVStorage(
+        working_dir, namespace="rephrase"
+    )
     qa_storage: JsonKVStorage = JsonKVStorage(
         os.path.join(working_dir, "data", "graphgen"), namespace=f"qa-{unique_id}"
     )
@@ -159,19 +162,22 @@ async def _insert_done(self):
             tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback())
         await asyncio.gather(*tasks)
 
-    def judge(self, re_judge=False):
+    def judge(self, re_judge=False, max_samples=1):
         loop = create_event_loop()
-        loop.run_until_complete(self.async_judge(re_judge))
+        loop.run_until_complete(self.async_judge(re_judge, max_samples))
 
-    async def async_judge(self, re_judge=False):
-        _update_relations = await judge_relations(self.teacher_llm_client, self.student_llm_client, self.graph_storage, re_judge)
+    async def async_judge(self, re_judge=False, max_samples=1):
+        _update_relations = await judge_relations(self.teacher_llm_client, self.student_llm_client,
+                                                  self.graph_storage, self.rephrase_storage, re_judge, max_samples)
         await _update_relations.index_done_callback()
+        await self.rephrase_storage.index_done_callback()
 
     def traverse(self):
         loop = create_event_loop()
         loop.run_until_complete(self.async_traverse())
 
     async def async_traverse(self):
-        results = await traverse_graph_by_edge(self.teacher_llm_client, self.tokenizer_instance, self.graph_storage, self.traverse_strategy)
+        results = await traverse_graph_by_edge(self.teacher_llm_client, self.tokenizer_instance,
+                                               self.graph_storage, self.traverse_strategy)
         await self.qa_storage.upsert(results)
         await self.qa_storage.index_done_callback()
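The new rephrase_storage gives the judging pass a persistent cache: sampled rephrasings are keyed by the original edge description, so a later run with re_judge=True reuses them instead of re-querying the teacher model. A minimal sketch of the pattern, with method names taken from this diff; sample_rephrasings is a hypothetical stand-in for the teacher-model sampling loop:

async def judge_with_cache(rephrase_storage, description: str):
    # Reuse previously sampled rephrasings when re-judging (pattern sketch).
    cached = await rephrase_storage.get_by_id(description)
    if not cached:
        cached = await sample_rephrasings(description)    # hypothetical helper
        await rephrase_storage.upsert({description: cached})
    return cached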
Lines changed: 44 additions & 25 deletions
@@ -1,24 +1,28 @@
+import math
 import asyncio
-from models import NetworkXStorage
-from utils import logger, yes_no_loss, detect_main_language
-from templates import ANTI_DESCRIPTION_REPHRASING_PROMPT, STATEMENT_JUDGEMENT_PROMPT
-from models import OpenAIModel
 from tqdm.asyncio import tqdm as tqdm_async
+from models import NetworkXStorage, OpenAIModel, JsonKVStorage
+from utils import logger, yes_no_loss_entropy, detect_main_language
+from templates import DESCRIPTION_REPHRASING_PROMPT, STATEMENT_JUDGEMENT_PROMPT
 
 
 async def judge_relations(
         teacher_llm_client: OpenAIModel,
         student_llm_client: OpenAIModel,
         graph_storage: NetworkXStorage,
+        rephrase_storage: JsonKVStorage,
         re_judge: bool = False,
+        max_samples: int = 1,
         max_concurrent: int = 1000) -> NetworkXStorage:
     """
     Get all edges and judge them
 
     :param teacher_llm_client: generate statements
     :param student_llm_client: judge the statements to get comprehension loss
     :param graph_storage: graph storage instance
+    :param rephrase_storage: rephrase storage instance
     :param re_judge: re-judge the relations
+    :param max_samples: max samples for each edge
     :param max_concurrent: max concurrent
     :return:
     """
@@ -38,33 +42,48 @@ async def _judge_single_relation(
            return source_id, target_id, edge_data
 
        description = edge_data["description"]
-
        language = "English" if detect_main_language(description) == "en" else "Chinese"
 
        try:
-            anti_description = await teacher_llm_client.generate_answer(
-                ANTI_DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(input_sentence=description)
-            )
-
-            judgement = await student_llm_client.generate_topk_per_token(
-                STATEMENT_JUDGEMENT_PROMPT['TEMPLATE'].format(statement=description)
-            )
-            anti_judgement = await student_llm_client.generate_topk_per_token(
-                STATEMENT_JUDGEMENT_PROMPT['TEMPLATE'].format(statement=anti_description)
-            )
-
-            loss = yes_no_loss(
-                [judgement[0].top_candidates, anti_judgement[0].top_candidates],
-                ['yes', 'no']
-            )
-
-            logger.info(f"Edge {source_id} -> {target_id} description: {description} loss: {loss}")
+            # If the description is already in rephrase_storage, reuse it directly
+            descriptions = await rephrase_storage.get_by_id(description)
+            if not descriptions:
+                # Sample multiple times and average
+                descriptions = [(description, 'yes')]
+                for i in range(max_samples):
+                    if i > 0:
+                        new_description = await teacher_llm_client.generate_answer(
+                            DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(input_sentence=description),
+                            temperature=1
+                        )
+                        descriptions.append((new_description, 'yes'))
+                    new_anti_description = await teacher_llm_client.generate_answer(
+                        DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(input_sentence=description),
+                        temperature=1
+                    )
+                    descriptions.append((new_anti_description, 'no'))
+
+                descriptions = list(set(descriptions))
+
+            await rephrase_storage.upsert({description: descriptions})
+
+            judgements = []
+            gts = [gt for _, gt in descriptions]
+            for description, gt in descriptions:
+                judgement = await student_llm_client.generate_topk_per_token(
+                    STATEMENT_JUDGEMENT_PROMPT['TEMPLATE'].format(statement=description)
+                )
+                judgements.append(judgement[0].top_candidates)
+
+            loss = yes_no_loss_entropy(judgements, gts)
+
+            logger.info("Edge %s -> %s description: %s loss: %s", source_id, target_id, description, loss)
 
            edge_data["loss"] = loss
-        except Exception as e:
+        except Exception as e: # pylint: disable=broad-except
            logger.error(f"Error in judging relation {source_id} -> {target_id}: {e}")
            logger.info("Use default loss 0.1")
-            edge_data["loss"] = 0.1
+            edge_data["loss"] = -math.log(0.1)
 
        await graph_storage.update_edge(source_id, target_id, edge_data)
        return source_id, target_id, edge_data
@@ -77,6 +96,6 @@ async def _judge_single_relation(
            total=len(edges),
            desc="Judging relations"
    ):
-        results.append(await result)
+        results.append(await result)
 
    return graph_storage
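The net effect: instead of one statement/negation pair per edge, the student model now judges a whole set of sampled statements, and the edge loss is aggregated over them. That is also why the fallback changes from 0.1 to -math.log(0.1): the default now sits on the same negative-log-probability scale as the computed loss. A minimal standalone sketch of that aggregation, assuming the entropy loss is the mean -log probability assigned to the correct yes/no answer; the names here are illustrative, not the repo's API:

import math

def edge_loss(p_yes_per_statement, ground_truths):
    # Mean -log p(correct answer) over all sampled statements for one edge.
    losses = []
    for p_yes, gt in zip(p_yes_per_statement, ground_truths):
        p_correct = p_yes if gt == "yes" else 1.0 - p_yes
        losses.append(-math.log(max(p_correct, 1e-9)))  # clamp to avoid log(0)
    return sum(losses) / len(losses)

# max_samples=3: original + 2 paraphrases ("yes") and 3 negations ("no")
print(edge_loss([0.9, 0.8, 0.85, 0.2, 0.1, 0.3], ["yes"] * 3 + ["no"] * 3))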

models/llm/openai_model.py

Lines changed: 5 additions & 1 deletion
@@ -71,6 +71,9 @@ async def generate_topk_per_token(self, text: str, history: Optional[List[str]]
         kwargs["logprobs"] = True
         kwargs["top_logprobs"] = self.topk_per_token
 
+        # Limit max_tokens to 2 to avoid long completions
+        kwargs["max_tokens"] = 2
+
         completion = await self.client.chat.completions.create(
             model=self.model_name,
             **kwargs
@@ -85,8 +88,9 @@
         wait=wait_exponential(multiplier=1, min=4, max=10),
         retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
     )
-    async def generate_answer(self, text: str, history: Optional[List[str]] = None) -> str:
+    async def generate_answer(self, text: str, history: Optional[List[str]] = None, temperature: int = 0) -> str:
         kwargs = self._pre_generate(text, history)
+        kwargs["temperature"] = temperature
 
         completion = await self.client.chat.completions.create(
             model=self.model_name,
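Two small but load-bearing changes here: capping generate_topk_per_token at two tokens keeps the judgement completion to essentially the bare yes/no token whose top-k logprobs feed the loss, and the new temperature parameter lets the judging code request diverse paraphrases (temperature=1) while answer generation stays deterministic by default (temperature=0). A sketch of the OpenAI-style logprob plumbing this relies on, to run inside an async function; the model name is a placeholder:

import math

async def p_yes_no(client, prompt: str) -> dict:
    completion = await client.chat.completions.create(
        model="<judge-model>",      # placeholder, not from the repo
        messages=[{"role": "user", "content": prompt}],
        logprobs=True,
        top_logprobs=5,
        max_tokens=2,               # room for just the bare "yes"/"no" token
    )
    # Top alternatives for the first generated token, as probabilities
    first = completion.choices[0].logprobs.content[0]
    return {c.token: math.exp(c.logprob) for c in first.top_logprobs}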

templates/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from .kg_extraction import KG_EXTRACTION_PROMPT
 from .kg_summarization import KG_SUMMARIZATION_PROMPT
 from .search_judgement import SEARCH_JUDGEMENT_PROMPT
-from .anti_description_rephrasing import ANTI_DESCRIPTION_REPHRASING_PROMPT
+from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
 from .statement_judgement import STATEMENT_JUDGEMENT_PROMPT
 from .answer_rephrasing import ANSWER_REPHRASING_PROMPT
 from .question_generation import QUESTION_GENERATION_PROMPT
Lines changed: 58 additions & 3 deletions
@@ -1,4 +1,4 @@
-TEMPLATE_EN: str = """-Goal-
+ANTI_TEMPLATE_EN: str = """-Goal-
 Transform the input sentence into its opposite meaning while:
 
 1. Preserving most of the original sentence structure
@@ -25,7 +25,7 @@
 Output:
 """
 
-TEMPLATE_ZH: str = """-目标-
+ANTI_TEMPLATE_ZH: str = """-目标-
 将输入句子转换为相反含义的句子,同时:
 
 1. 保留大部分原始句子结构
@@ -52,11 +52,66 @@
 输出:
 """
 
-ANTI_DESCRIPTION_REPHRASING_PROMPT= {
+TEMPLATE_ZH: str = """-目标-
+将输入句子转换为相同含义的句子,同时:
+
+1. 保留大部分原始句子结构
+2. 仅更改影响核心含义的关键词
+3. 保持相同的语气和风格
+4. 输出句子应该流畅且语法正确
+
+################
+-示例-
+################
+输入:
+明亮的阳光让每个人都感到充满活力和快乐。
+
+输出:
+明媚的阳光让每个人都感受到活力与快乐。
+
+################
+-真实数据-
+################
+输入:
+{input_sentence}
+################
+输出:
+"""
+
+TEMPLATE_EN: str = """-Goal-
+Transform the input sentence into a sentence with the same meaning while:
+
+1. Preserving most of the original sentence structure
+2. Changing only key words that affect the core meaning
+3. Maintaining the same tone and style
+4. The output sentence should be fluent and grammatically correct
+
+################
+-Examples-
+################
+Input:
+The bright sunshine made everyone feel energetic and happy.
+
+Output:
+The bright sunshine made everyone feel energetic and joyful.
+
+################
+-Real Data-
+################
+Input:
+{input_sentence}
+################
+Output:
+"""
+
+
+DESCRIPTION_REPHRASING_PROMPT= {
     "English": {
+        "ANTI_TEMPLATE": ANTI_TEMPLATE_EN,
         "TEMPLATE": TEMPLATE_EN
     },
     "Chinese": {
+        "ANTI_TEMPLATE": ANTI_TEMPLATE_ZH,
         "TEMPLATE": TEMPLATE_ZH
     }
 }
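The rename reflects the new structure: one dict now exposes both a meaning-preserving TEMPLATE and a meaning-inverting ANTI_TEMPLATE per language. Usage, as in the judging code above; the input sentence is the example taken from the template itself:

sentence = "The bright sunshine made everyone feel energetic and happy."
paraphrase_prompt = DESCRIPTION_REPHRASING_PROMPT["English"]["TEMPLATE"].format(
    input_sentence=sentence)       # ground truth stays "yes"
negation_prompt = DESCRIPTION_REPHRASING_PROMPT["English"]["ANTI_TEMPLATE"].format(
    input_sentence=sentence)       # ground truth flips to "no"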

utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -5,5 +5,5 @@
                     load_json, write_json)
 from .hash import compute_content_hash, compute_args_hash
 from .detect_lang import detect_main_language, detect_if_chinese
-from .calculate_confidence import yes_no_loss
+from .calculate_confidence import yes_no_loss_entropy
 from .help_nltk import NLTKHelper

utils/calculate_confidence.py

Lines changed: 6 additions & 8 deletions
@@ -10,30 +10,29 @@ def preprocess_tokens(tokens: List[Token]) -> List[Token]:
 def joint_probability(tokens: List[Token]) -> float:
     """Calculate joint probability of a list of tokens."""
     tokens = preprocess_tokens(tokens)
-    logprob_sum = sum([x.logprob for x in tokens])
+    logprob_sum = sum(x.logprob for x in tokens)
     return math.exp(logprob_sum / len(tokens))
 
 def min_prob(tokens: List[Token]) -> float:
     """Calculate the minimum probability of a list of tokens."""
     tokens = preprocess_tokens(tokens)
-    return min([x.prob for x in tokens])
+    return min(x.prob for x in tokens)
 
 def average_prob(tokens: List[Token]) -> float:
     """Calculate the average probability of a list of tokens."""
     tokens = preprocess_tokens(tokens)
-    return sum([x.prob for x in tokens]) / len(tokens)
+    return sum(x.prob for x in tokens) / len(tokens)
 
 def average_confidence(tokens: List[Token]) -> float:
     """Calculate the average confidence of a list of tokens."""
     tokens = preprocess_tokens(tokens)
-    confidence = [x.prob / sum([y.prob for y in x.top_candidates[:5]]) for x in tokens]
+    confidence = [x.prob / sum(y.prob for y in x.top_candidates[:5]) for x in tokens]
     return sum(confidence) / len(tokens)
 
 def yes_no_loss(tokens_list: List[List[Token]], ground_truth: List[str]) -> float:
     """Calculate the loss for yes/no question."""
     losses = []
-    for i in range(len(tokens_list)):
-        tokens = tokens_list[i]
+    for i, tokens in enumerate(tokens_list):
         token = tokens[0]
         assert token.text in ["yes", "no"]
         if token.text == ground_truth[i]:
@@ -45,8 +44,7 @@ def yes_no_loss(tokens_list: List[List[Token]], ground_truth: List[str]) -> float:
 def yes_no_loss_entropy(tokens_list: List[List[Token]], ground_truth: List[str]) -> float:
     """Calculate the loss for yes/no question using entropy."""
     losses = []
-    for i in range(len(tokens_list)):
-        tokens = tokens_list[i]
+    for i, tokens in enumerate(tokens_list):
         token = tokens[0]
         assert token.text in ["yes", "no"]
         if token.text == ground_truth[i]:
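The hunk cuts off mid-function, so the branch bodies of yes_no_loss_entropy are not shown. A hypothetical completion, consistent with what is visible and with the -math.log(0.1) fallback in the judging code; Token is the type imported at the top of the module, and the real implementation may differ:

import math

def yes_no_loss_entropy(tokens_list, ground_truth):
    """Mean -log p(correct answer): low when the student is confidently right."""
    losses = []
    for i, tokens in enumerate(tokens_list):
        token = tokens[0]
        assert token.text in ["yes", "no"]
        if token.text == ground_truth[i]:
            losses.append(-math.log(token.prob))                  # right answer
        else:
            losses.append(-math.log(max(1 - token.prob, 1e-9)))   # wrong answer
    return sum(losses) / len(losses)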
