Skip to content

Commit db7a071

Browse files
feat(graphgen): calculate node loss
1 parent c3ef35b commit db7a071

File tree

10 files changed

+138
-42
lines changed

10 files changed

+138
-42
lines changed

generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373

7474
graph_gen.insert(data, args.data_type)
7575

76-
graph_gen.quiz(max_samples=3)
76+
graph_gen.quiz(max_samples=2)
7777

7878
graph_gen.judge(re_judge=False)
7979

graphgen/graphgen.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from models import Chunk, JsonKVStorage, OpenAIModel, NetworkXStorage, WikiSearch, Tokenizer, TraverseStrategy
1111
from models.storage.base_storage import StorageNameSpace
1212
from utils import create_event_loop, logger, compute_content_hash
13-
from .operators import extract_kg, search_wikipedia, quiz_relations, judge_relations, traverse_graph_by_edge
13+
from .operators import extract_kg, search_wikipedia, quiz, judge_statement, traverse_graph_by_edge
1414

1515

1616
sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -169,15 +169,15 @@ def quiz(self, max_samples=1):
169169
loop.run_until_complete(self.async_quiz(max_samples))
170170

171171
async def async_quiz(self, max_samples=1):
172-
await quiz_relations(self.teacher_llm_client, self.graph_storage, self.rephrase_storage, max_samples)
172+
await quiz(self.teacher_llm_client, self.graph_storage, self.rephrase_storage, max_samples)
173173
await self.rephrase_storage.index_done_callback()
174174

175175
def judge(self, re_judge=False):
176176
loop = create_event_loop()
177177
loop.run_until_complete(self.async_judge(re_judge))
178178

179179
async def async_judge(self, re_judge=False):
180-
_update_relations = await judge_relations(self.student_llm_client, self.graph_storage,
180+
_update_relations = await judge_statement(self.student_llm_client, self.graph_storage,
181181
self.rephrase_storage, re_judge)
182182
await _update_relations.index_done_callback()
183183

graphgen/operators/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from .extract_kg import extract_kg
2-
from .quiz_relations import quiz_relations
3-
from .judge_relations import judge_relations
2+
from .quiz import quiz
3+
from .judge import judge_statement
44
from .search_wikipedia import search_wikipedia
55
from .traverse_graph import traverse_graph_by_edge
66

77
__all__ = [
88
"extract_kg",
9-
"quiz_relations",
10-
"judge_relations",
9+
"quiz",
10+
"judge_statement",
1111
"search_wikipedia",
1212
"traverse_graph_by_edge"
1313
]
Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
from templates import STATEMENT_JUDGEMENT_PROMPT
77

88

9-
async def judge_relations(
9+
async def judge_statement(
1010
student_llm_client: OpenAIModel,
1111
graph_storage: NetworkXStorage,
1212
rephrase_storage: JsonKVStorage,
1313
re_judge: bool = False,
1414
max_concurrent: int = 1000) -> NetworkXStorage:
1515
"""
16-
Get all edges and judge them
16+
Get all edges and nodes and judge them
1717
1818
:param student_llm_client: judge the statements to get comprehension loss
1919
:param graph_storage: graph storage instance
@@ -74,4 +74,52 @@ async def _judge_single_relation(
7474
):
7575
results.append(await result)
7676

77+
async def _judge_single_entity(
78+
node: tuple,
79+
):
80+
async with semaphore:
81+
node_id = node[0]
82+
node_data = node[1]
83+
84+
if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
85+
logger.info("Node %s already judged, loss: %s, skip", node_id, node_data["loss"])
86+
return node_id, node_data
87+
88+
description = node_data["description"]
89+
90+
try:
91+
descriptions = await rephrase_storage.get_by_id(description)
92+
assert descriptions is not None
93+
94+
judgements = []
95+
gts = [gt for _, gt in descriptions]
96+
for description, gt in descriptions:
97+
judgement = await student_llm_client.generate_topk_per_token(
98+
STATEMENT_JUDGEMENT_PROMPT['TEMPLATE'].format(statement=description)
99+
)
100+
judgements.append(judgement[0].top_candidates)
101+
102+
loss = yes_no_loss_entropy(judgements, gts)
103+
104+
logger.info("Node %s description: %s loss: %s", node_id, description, loss)
105+
106+
node_data["loss"] = loss
107+
except Exception as e: # pylint: disable=broad-except
108+
logger.error("Error in judging entity %s: %s", node_id, e)
109+
logger.info("Use default loss 0.1")
110+
node_data["loss"] = -math.log(0.1)
111+
112+
await graph_storage.update_node(node_id, node_data)
113+
return node_id, node_data
114+
115+
nodes = await graph_storage.get_all_nodes()
116+
117+
results = []
118+
for result in tqdm_async(
119+
asyncio.as_completed([_judge_single_entity(node) for node in nodes]),
120+
total=len(nodes),
121+
desc="Judging entities"
122+
):
123+
results.append(await result)
124+
77125
return graph_storage
Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from templates import DESCRIPTION_REPHRASING_PROMPT
88

99

10-
async def quiz_relations(
10+
async def quiz(
1111
teacher_llm_client: OpenAIModel,
1212
graph_storage: NetworkXStorage,
1313
rephrase_storage: JsonKVStorage,
@@ -26,16 +26,12 @@ async def quiz_relations(
2626

2727
semaphore = asyncio.Semaphore(max_concurrent)
2828

29-
async def _quiz_single_relation(
30-
edge: tuple,
29+
async def _process_single_quiz(
3130
des: str,
3231
prompt: str,
3332
gt: str
3433
):
3534
async with semaphore:
36-
source_id = edge[0]
37-
target_id = edge[1]
38-
3935
try:
4036
# 如果在rephrase_storage中已经存在,直接取出
4137
descriptions = await rephrase_storage.get_by_id(des)
@@ -49,11 +45,12 @@ async def _quiz_single_relation(
4945
return {des: [(new_description, gt)]}
5046

5147
except Exception as e: # pylint: disable=broad-except
52-
logger.error("Error when quizzing edge %s -> %s: %s", source_id, target_id, e)
48+
logger.error("Error when quizzing description %s: %s", des, e)
5349
return None
5450

5551

5652
edges = await graph_storage.get_all_edges()
53+
nodes = await graph_storage.get_all_nodes()
5754

5855
results = defaultdict(list)
5956
tasks = []
@@ -68,19 +65,36 @@ async def _quiz_single_relation(
6865
for i in range(max_samples):
6966
if i > 0:
7067
tasks.append(
71-
_quiz_single_relation(edge, description,
68+
_process_single_quiz(description,
7269
DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(
7370
input_sentence=description), 'yes')
7471
)
75-
tasks.append(_quiz_single_relation(edge, description,
72+
tasks.append(_process_single_quiz(description,
7673
DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(
7774
input_sentence=description), 'no'))
7875

76+
for node in nodes:
77+
node_data = node[1]
78+
description = node_data["description"]
79+
language = "English" if detect_main_language(description) == "en" else "Chinese"
80+
81+
results[description] = [(description, 'yes')]
82+
83+
for i in range(max_samples):
84+
if i > 0:
85+
tasks.append(
86+
_process_single_quiz(description,
87+
DESCRIPTION_REPHRASING_PROMPT[language]['TEMPLATE'].format(
88+
input_sentence=description), 'yes')
89+
)
90+
tasks.append(_process_single_quiz(description,
91+
DESCRIPTION_REPHRASING_PROMPT[language]['ANTI_TEMPLATE'].format(
92+
input_sentence=description), 'no'))
7993

8094
for result in tqdm_async(
8195
asyncio.as_completed(tasks),
8296
total=len(tasks),
83-
desc="Quizzing relations"
97+
desc="Quizzing descriptions"
8498
):
8599
new_result = await result
86100
if new_result:
@@ -91,4 +105,5 @@ async def _quiz_single_relation(
91105
results[key] = list(set(value))
92106
await rephrase_storage.upsert({key: results[key]})
93107

108+
94109
return rephrase_storage

graphgen/operators/split_graph.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ async def get_cached_node_info(node_id: str) -> dict:
260260

261261
processing_batches.append((_process_nodes, _process_edges))
262262

263+
l
263264
# isolate nodes
264265
isolated_node_strategy = traverse_strategy.isolated_node_strategy
265266
if isolated_node_strategy == "add":

graphgen/operators/traverse_graph.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ def get_loss_tercile(losses: list) -> (float, float):
5757

5858
return losses[q1_index], losses[q2_index]
5959

60+
def get_average_loss(batch: tuple) -> float:
61+
return sum(edge[2]['loss'] for edge in batch[1]) + sum(node['loss'] for node in batch[0]) / \
62+
(len(batch[0]) + len(batch[1]))
63+
6064
async def traverse_graph_by_edge(
6165
llm_client: OpenAIModel,
6266
tokenizer: Tokenizer,
@@ -114,8 +118,6 @@ async def _process_single_batch(
114118
_process_batch: tuple
115119
) -> dict:
116120
async with semaphore:
117-
losses = [(edge[0], edge[1], edge[2]['loss']) for edge in _process_batch[1]]
118-
119121
context = await _process_nodes_and_edges(
120122
_process_batch[0],
121123
_process_batch[1],
@@ -145,14 +147,14 @@ async def _process_single_batch(
145147
compute_content_hash(context): {
146148
"question": question,
147149
"answer": context,
148-
"losses": losses,
150+
"loss": get_average_loss(_process_batch),
149151
"difficulty": _process_batch[2],
150152
}
151153
}
152154

153155
results = {}
154156
edges = list(await graph_storage.get_all_edges())
155-
nodes = await graph_storage.get_all_nodes()
157+
nodes = list(await graph_storage.get_all_nodes())
156158

157159
edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
158160

@@ -165,18 +167,13 @@ async def _process_single_batch(
165167

166168
losses = []
167169
for batch in processing_batches:
168-
if len(batch[1]) == 0:
169-
continue
170-
loss = sum(edge[2]['loss'] for edge in batch[1]) / len(batch[1])
170+
loss = get_average_loss(batch)
171171
losses.append(loss)
172172
q1, q2 = get_loss_tercile(losses)
173173

174174
difficulty_order = traverse_strategy.difficulty_order
175175
for i, batch in enumerate(processing_batches):
176-
if len(batch[1]) == 0:
177-
processing_batches[i] = (batch[0], batch[1], difficulty_order[0])
178-
continue
179-
loss = sum(edge[2]['loss'] for edge in batch[1]) / len(batch[1])
176+
loss = get_average_loss(batch)
180177
if loss < q1:
181178
# easy
182179
processing_batches[i] = (batch[0], batch[1], difficulty_order[0])

judge.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import os
2+
import argparse
3+
import asyncio
4+
from dotenv import load_dotenv
5+
6+
from models import NetworkXStorage, JsonKVStorage, OpenAIModel
7+
from graphgen.operators import judge_relations
8+
9+
sys_path = os.path.abspath(os.path.dirname(__file__))
10+
11+
load_dotenv()
12+
13+
14+
if __name__ == '__main__':
15+
parser = argparse.ArgumentParser()
16+
parser.add_argument('--output', type=str, default='cache/output/new_graph.graphml', help='path to save output')
17+
18+
args = parser.parse_args()
19+
20+
llm_client = OpenAIModel(
21+
model_name=os.getenv("STUDENT_MODEL"),
22+
api_key=os.getenv("STUDENT_API_KEY"),
23+
base_url=os.getenv("STUDENT_BASE_URL")
24+
)
25+
26+
graph_storage = NetworkXStorage(
27+
os.path.join(sys_path, "cache"),
28+
namespace="graph"
29+
)
30+
31+
rephrase_storage = JsonKVStorage(
32+
os.path.join(sys_path, "cache"),
33+
namespace="rephrase"
34+
)
35+
36+
new_graph = asyncio.run(judge_relations(llm_client, graph_storage, rephrase_storage, re_judge=True))
37+
38+
graph_file = asyncio.run(graph_storage.get_graph())
39+
40+
new_graph.write_nx_graph(graph_file, args.output)

models/storage/base_storage.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import numpy as np
2-
31
from dataclasses import dataclass
4-
from models.embed.embedding import EmbeddingFunc
52
from typing import Union, Generic, TypeVar
3+
from models.embed.embedding import EmbeddingFunc
64

75
T = TypeVar("T")
86

@@ -95,6 +93,3 @@ async def upsert_edge(
9593

9694
async def delete_node(self, node_id: str):
9795
raise NotImplementedError
98-
99-
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
100-
raise NotImplementedError("Node embedding is not used in lightrag.")

models/strategy/travserse_strategy.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,19 @@ class TraverseStrategy(BaseStrategy):
88
# 最大边数和最大token数方法中选择一个生效
99
expand_method: str = "max_tokens" # "max_width" or "max_tokens"
1010
# 单向拓展还是双向拓展
11-
bidirectional: bool = False
11+
bidirectional: bool = True
1212
# 每个方向拓展的最大边数
1313
max_extra_edges: int = 5
1414
# 最长token数
15-
max_tokens: int = 1024
15+
max_tokens: int = 512
1616
# 每个方向拓展的最大深度
17-
max_depth: int = 3
17+
max_depth: int = 5
1818
# 同一层中选边的策略(如果是双向拓展,同一层指的是两边连接的边的集合)
1919
edge_sampling: str = "max_loss" # "max_loss" or "min_loss" or "random"
2020
# 孤立节点的处理策略
21-
isolated_node_strategy: str = "ignore" # "add" or "ignore"
21+
isolated_node_strategy: str = "add" # "add" or "ignore"
2222
# 难度顺序 ["easy", "medium", "hard"], ["hard", "medium", "easy"], ["medium", "medium", "medium"]
23-
difficulty_order: list = field(default_factory=lambda: ["easy", "medium", "hard"])
23+
difficulty_order: list = field(default_factory=lambda: ["medium", "medium", "medium"])
2424

2525
def to_yaml(self):
2626
return {

0 commit comments

Comments
 (0)