Skip to content

Commit bbb66ad

Browse files
feat(templates): specify 3 levels of rephrasing difficulty
1 parent 3840d68 commit bbb66ad

File tree

5 files changed

+124
-23
lines changed

5 files changed

+124
-23
lines changed

graphgen/graphgen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ def judge(self, re_judge=False):
177177
loop.run_until_complete(self.async_judge(re_judge))
178178

179179
async def async_judge(self, re_judge=False):
180-
_update_relations = await judge_relations(self.teacher_llm_client, self.student_llm_client,
181-
self.graph_storage, self.rephrase_storage, re_judge)
180+
_update_relations = await judge_relations(self.student_llm_client, self.graph_storage,
181+
self.rephrase_storage, re_judge)
182182
await _update_relations.index_done_callback()
183183

184184
def traverse(self):

graphgen/operators/judge_relations.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
import asyncio
33
from tqdm.asyncio import tqdm as tqdm_async
44
from models import NetworkXStorage, OpenAIModel, JsonKVStorage
5-
from utils import logger, yes_no_loss_entropy, detect_main_language
6-
from templates import DESCRIPTION_REPHRASING_PROMPT, STATEMENT_JUDGEMENT_PROMPT
5+
from utils import logger, yes_no_loss_entropy
6+
from templates import STATEMENT_JUDGEMENT_PROMPT
77

88

99
async def judge_relations(
10-
teacher_llm_client: OpenAIModel,
1110
student_llm_client: OpenAIModel,
1211
graph_storage: NetworkXStorage,
1312
rephrase_storage: JsonKVStorage,
@@ -16,7 +15,6 @@ async def judge_relations(
1615
"""
1716
Get all edges and judge them
1817
19-
:param teacher_llm_client: generate statements
2018
:param student_llm_client: judge the statements to get comprehension loss
2119
:param graph_storage: graph storage instance
2220
:param rephrase_storage: rephrase storage instance

graphgen/operators/split_graph.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
import asyncio
21
import random
3-
42
from collections import defaultdict
5-
from models import NetworkXStorage, TraverseStrategy
63
from tqdm.asyncio import tqdm as tqdm_async
7-
from utils import logger, create_event_loop
4+
from utils import logger
5+
6+
from models import NetworkXStorage, TraverseStrategy
87

98

109
async def _get_node_info(
@@ -103,7 +102,7 @@ def _get_level_n_edges_by_max_tokens(
103102
) -> list:
104103
"""
105104
Get level n edges for an edge.
106-
n is decided by max_depth in traverse_strategy
105+
n is decided by max_depth in traverse_strategy.
107106
108107
:param edge_adj_list
109108
:param node_dict

graphgen/operators/traverse_graph.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ async def handle_node(node: dict) -> dict:
4949
await graph_storage.index_done_callback()
5050
return new_edges, new_nodes
5151

52+
53+
def get_loss_tercile(losses: list) -> (float, float):
54+
losses = sorted(losses)
55+
q1_index = int(len(losses) * (1 / 3))
56+
q2_index = int(len(losses) * (2 / 3))
57+
58+
return losses[q1_index], losses[q2_index]
59+
5260
async def traverse_graph_by_edge(
5361
llm_client: OpenAIModel,
5462
tokenizer: Tokenizer,
@@ -72,6 +80,7 @@ async def traverse_graph_by_edge(
7280
async def _process_nodes_and_edges(
7381
_process_nodes: list,
7482
_process_edges: list,
83+
_difficulty: str
7584
) -> str:
7685
entities = [
7786
f"{_process_node['node_id']}: {_process_node['description']}" for _process_node in _process_nodes
@@ -85,7 +94,7 @@ async def _process_nodes_and_edges(
8594
relations_str = "\n".join([f"{index + 1}. {relation}" for index, relation in enumerate(relations)])
8695

8796
language = "Chinese" if detect_main_language(entities_str + relations_str) == "zh" else "English"
88-
prompt = ANSWER_REPHRASING_PROMPT[language]['TEMPLATE'].format(
97+
prompt = ANSWER_REPHRASING_PROMPT[_difficulty][language]['TEMPLATE'].format(
8998
language=language,
9099
entities=entities_str,
91100
relationships=relations_str
@@ -105,9 +114,12 @@ async def _process_single_batch(
105114
_process_batch: tuple
106115
) -> dict:
107116
async with semaphore:
117+
losses = [(edge[0], edge[1], edge[2]['loss']) for edge in _process_batch[1]]
118+
108119
context = await _process_nodes_and_edges(
109120
_process_batch[0],
110121
_process_batch[1],
122+
_process_batch[2]
111123
)
112124

113125
language = "Chinese" if detect_main_language(context) == "zh" else "English"
@@ -125,8 +137,6 @@ async def _process_single_batch(
125137
pre_length = sum(node['length'] for node in _process_batch[0]) \
126138
+ sum(edge[2]['length'] for edge in _process_batch[1])
127139

128-
losses = [(edge[0], edge[1], edge[2]['loss']) for edge in _process_batch[1]]
129-
130140
logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
131141
logger.info("Pre-length: %s", pre_length)
132142
logger.info("Question: %s Answer: %s", question, context)
@@ -135,7 +145,8 @@ async def _process_single_batch(
135145
compute_content_hash(context): {
136146
"question": question,
137147
"answer": context,
138-
"losses": losses
148+
"losses": losses,
149+
"difficulty": _process_batch[2],
139150
}
140151
}
141152

@@ -152,6 +163,29 @@ async def _process_single_batch(
152163
traverse_strategy
153164
)
154165

166+
losses = []
167+
for batch in processing_batches:
168+
if len(batch[1]) == 0:
169+
continue
170+
loss = sum(edge[2]['loss'] for edge in batch[1]) / len(batch[1])
171+
losses.append(loss)
172+
q1, q2 = get_loss_tercile(losses)
173+
174+
for i, batch in enumerate(processing_batches):
175+
if len(batch[1]) == 0:
176+
processing_batches[i] = (batch[0], batch[1], "easy")
177+
continue
178+
loss = sum(edge[2]['loss'] for edge in batch[1]) / len(batch[1])
179+
if loss < q1:
180+
# easy
181+
processing_batches[i] = (batch[0], batch[1], "easy")
182+
elif loss < q2:
183+
# medium
184+
processing_batches[i] = (batch[0], batch[1], "medium")
185+
else:
186+
# hard
187+
processing_batches[i] = (batch[0], batch[1], "hard")
188+
155189
for result in tqdm_async(asyncio.as_completed(
156190
[_process_single_batch(batch) for batch in processing_batches]
157191
), total=len(processing_batches), desc="Processing batches"):

templates/answer_rephrasing.py

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,6 @@
4343
################
4444
{relationships}
4545
46-
################
47-
Please directly output the rephrased text below, without any additional content.
48-
49-
Rephrased Text:
5046
"""
5147

5248
TEMPLATE_ZH: str = """---角色---
@@ -92,18 +88,92 @@
9288
################
9389
{relationships}
9490
91+
"""
92+
93+
EASY_REQUIREMENT_EN = """
94+
---Requirements---
95+
- Requires a concise and straightforward summary, focusing on core meaning.
96+
- Uses simple language, avoiding complex sentence structures.
97+
- Does not need excessive details or examples; just the basic concepts and relationships.
98+
99+
################
100+
Please directly output the rephrased text below, without any additional content.
101+
102+
Rephrased Text:
103+
"""
104+
105+
EASY_REQUIREMENT_ZH = """
106+
---要求---
107+
- 要求简洁明了,主要传达核心意思。
108+
- 使用简单的语言,避免复杂的句子结构。
109+
- 不需要过多的细节或示例,只需基本概念和关系。
110+
111+
################
112+
请在下方直接输出重述文本,不要输出任何额外的内容。
113+
114+
重述文本:
115+
"""
116+
117+
MEDIUM_REQUIREMENT_ZH = """
95118
################
96119
请在下方直接输出重述文本,不要输出任何额外的内容。
97120
98121
重述文本:
99122
"""
100123

101124

125+
MEDIUM_REQUIREMENT_EN = """
126+
################
127+
Please directly output the rephrased text below, without any additional content.
128+
129+
Rephrased Text:
130+
"""
131+
132+
HARD_REQUIREMENT_EN = """
133+
---Requirements---
134+
- Requires an in-depth exploration of complex relationships and nuances.
135+
- Includes detailed background information, emphasizing logical consistency and complexity.
136+
137+
################
138+
Please directly output the rephrased text below, without any additional content.
139+
140+
Rephrased Text:
141+
"""
142+
143+
HARD_REQUIREMENT_ZH = """
144+
---要求---
145+
- 需要深入探讨复杂的关系和细微差别。
146+
- 包括详细的背景信息,强调逻辑一致性和复杂性。
147+
148+
################
149+
请在下方直接输出重述文本,不要输出任何额外的内容。
150+
151+
重述文本:
152+
"""
153+
102154
ANSWER_REPHRASING_PROMPT= {
103-
"English": {
104-
"TEMPLATE": TEMPLATE_EN
155+
"easy": {
156+
"English": {
157+
"TEMPLATE": TEMPLATE_EN + EASY_REQUIREMENT_EN
158+
},
159+
"Chinese": {
160+
"TEMPLATE": TEMPLATE_ZH + EASY_REQUIREMENT_ZH
161+
}
162+
},
163+
"medium": {
164+
"English": {
165+
"TEMPLATE": TEMPLATE_EN + MEDIUM_REQUIREMENT_EN
166+
},
167+
"Chinese": {
168+
"TEMPLATE": TEMPLATE_ZH + MEDIUM_REQUIREMENT_ZH
169+
}
105170
},
106-
"Chinese": {
107-
"TEMPLATE": TEMPLATE_ZH
171+
"hard": {
172+
"English": {
173+
"TEMPLATE": TEMPLATE_EN + HARD_REQUIREMENT_EN
174+
},
175+
"Chinese": {
176+
"TEMPLATE": TEMPLATE_ZH + HARD_REQUIREMENT_ZH
177+
}
108178
}
109179
}

0 commit comments

Comments
 (0)