
Commit 2fbc3c2

refactor: refactor using LightRAGKGBuilder
1 parent b30f5a1 commit 2fbc3c2

File tree

9 files changed: +258 −201 lines


graphgen/bases/base_kg_builder.py

Lines changed: 14 additions & 11 deletions
@@ -10,22 +10,13 @@
 
 @dataclass
 class BaseKGBuilder(ABC):
-    kg_instance: BaseGraphStorage
     llm_client: BaseLLMClient
 
     _nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))
     _edges: Dict[Tuple[str, str], List[dict]] = field(
         default_factory=lambda: defaultdict(list)
     )
 
-    def build(self, chunks: List[Chunk]) -> None:
-        pass
-
-    @abstractmethod
-    async def extract_all(self, chunks: List[Chunk]) -> None:
-        """Extract nodes and edges from all chunks."""
-        raise NotImplementedError
-
     @abstractmethod
     async def extract(
         self, chunk: Chunk
@@ -35,7 +26,19 @@ async def extract(
 
     @abstractmethod
     async def merge_nodes(
-        self, nodes_data: Dict[str, List[dict]], kg_instance: BaseGraphStorage, llm
-    ) -> None:
+        self,
+        entity_name: str,
+        node_data: Dict[str, List[dict]],
+        kg_instance: BaseGraphStorage,
+    ) -> BaseGraphStorage:
         """Merge extracted nodes into the knowledge graph."""
         raise NotImplementedError
+
+    @abstractmethod
+    async def merge_edges(
+        self,
+        edges_data: Dict[Tuple[str, str], List[dict]],
+        kg_instance: BaseGraphStorage,
+    ) -> BaseGraphStorage:
+        """Merge extracted edges into the knowledge graph."""
+        raise NotImplementedError
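For orientation, here is a minimal driver sketch (not part of the commit) of how the reshaped interface composes: extract runs once per chunk, merge_nodes now merges a single named entity at a time and threads the storage through, and the new merge_edges folds the accumulated edges in afterwards. The builder, chunks, and storage arguments are assumed to be a concrete BaseKGBuilder, a list of Chunk, and a BaseGraphStorage; the real orchestrator (build_kg, below) may wire things differently.

from collections import defaultdict

# Hypothetical driver for the new BaseKGBuilder interface (illustrative only).
async def drive(builder, chunks, storage):
    nodes, edges = defaultdict(list), defaultdict(list)
    for chunk in chunks:
        n, e = await builder.extract(chunk)  # per-chunk extraction
        for name, data in n.items():
            nodes[name].extend(data)         # group mentions by entity name
        for key, data in e.items():
            edges[key].extend(data)          # group by (src, tgt) pair
    for name, data in nodes.items():
        # merge one entity at a time, as the new per-entity signature suggests
        storage = await builder.merge_nodes(name, {name: data}, storage)
    return await builder.merge_edges(dict(edges), storage)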

graphgen/graphgen.py

Lines changed: 2 additions & 2 deletions
@@ -17,8 +17,8 @@
     TraverseStrategy,
 )
 from graphgen.operators import (
+    build_kg,
     chunk_documents,
-    extract_kg,
     generate_cot,
     judge_statement,
     quiz,
@@ -164,7 +164,7 @@ async def insert(self):
 
         # Step 3: Extract entities and relations from chunks
         logger.info("[Entity and Relation Extraction]...")
-        _add_entities_and_relations = await extract_kg(
+        _add_entities_and_relations = await build_kg(
             llm_client=self.synthesizer_llm_client,
             kg_instance=self.graph_storage,
             tokenizer_instance=self.tokenizer_instance,

graphgen/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 from .evaluate.mtld_evaluator import MTLDEvaluator
 from .evaluate.reward_evaluator import RewardEvaluator
 from .evaluate.uni_evaluator import UniEvaluator
+from .kg_builder.light_rag_kg_builder import LightRAGKGBuilder
 from .llm.openai_client import OpenAIClient
 from .llm.topk_token_model import TopkTokenModel
 from .reader import CsvReader, JsonlReader, JsonReader, TxtReader

graphgen/models/kg_builder/NetworkXKGBuilder.py

Lines changed: 0 additions & 18 deletions
This file was deleted.
graphgen/models/kg_builder/light_rag_kg_builder.py

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
+import re
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Dict, List, Tuple
+
+from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
+from graphgen.templates import KG_EXTRACTION_PROMPT
+from graphgen.utils import (
+    detect_if_chinese,
+    handle_single_entity_extraction,
+    handle_single_relationship_extraction,
+    logger,
+    pack_history_conversations,
+    split_string_by_multi_markers,
+)
+
+
+@dataclass
+class LightRAGKGBuilder(BaseKGBuilder):
+    llm_client: BaseLLMClient = None
+    max_loop: int = 3
+
+    async def extract(
+        self, chunk: Chunk
+    ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
+        """
+        Extract entities and relationships from a single chunk using the LLM client.
+        :param chunk: the text chunk to process
+        :return: (nodes_data, edges_data)
+        """
+        chunk_id = chunk.id
+        content = chunk.content
+
+        # step 1: language detection
+        language = "Chinese" if detect_if_chinese(content) else "English"
+        KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
+
+        hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
+            **KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
+        )
+
+        # step 2: initial glean
+        final_result = await self.llm_client.generate_answer(hint_prompt)
+        logger.debug("First extraction result: %s", final_result)
+
+        # step 3: iterative refinement
+        history = pack_history_conversations(hint_prompt, final_result)
+        for loop_idx in range(self.max_loop):
+            if_loop_result = await self.llm_client.generate_answer(
+                text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
+            )
+            if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
+            if if_loop_result != "yes":
+                break
+
+            glean_result = await self.llm_client.generate_answer(
+                text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
+            )
+            logger.debug("Loop %s glean: %s", loop_idx + 1, glean_result)
+
+            history += pack_history_conversations(
+                KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
+            )
+            final_result += glean_result
+
+        # step 4: parse the final result
+        records = split_string_by_multi_markers(
+            final_result,
+            [
+                KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
+                KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
+            ],
+        )
+
+        nodes = defaultdict(list)
+        edges = defaultdict(list)
+
+        for record in records:
+            match = re.search(r"\((.*)\)", record)
+            if not match:
+                continue
+            inner = match.group(1)
+
+            attributes = split_string_by_multi_markers(
+                inner, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
+            )
+
+            entity = await handle_single_entity_extraction(attributes, chunk_id)
+            if entity is not None:
+                nodes[entity["entity_name"]].append(entity)
+                continue
+
+            relation = await handle_single_relationship_extraction(attributes, chunk_id)
+            if relation is not None:
+                key = (relation["src_id"], relation["tgt_id"])
+                edges[key].append(relation)
+
+        return dict(nodes), dict(edges)
+
+    async def merge_nodes(
+        self,
+        entity_name: str,
+        node_data: Dict[str, List[dict]],
+        kg_instance: BaseGraphStorage,
+    ) -> BaseGraphStorage:
+        pass
+
+    async def merge_edges(
+        self,
+        edges_data: Dict[Tuple[str, str], List[dict]],
+        kg_instance: BaseGraphStorage,
+    ) -> BaseGraphStorage:
+        pass
+
+    # async def process_single_node(entity_name: str, node_data: list[dict]):
+    #     entity_types = []
+    #     source_ids = []
+    #     descriptions = []
+    #
+    #     node = await kg_instance.get_node(entity_name)
+    #     if node is not None:
+    #         entity_types.append(node["entity_type"])
+    #         source_ids.extend(
+    #             split_string_by_multi_markers(node["source_id"], ["<SEP>"])
+    #         )
+    #         descriptions.append(node["description"])
+    #
+    #     # count entity_type occurrences in new and existing node data; keep the most frequent
+    #     entity_type = sorted(
+    #         Counter([dp["entity_type"] for dp in node_data] + entity_types).items(),
+    #         key=lambda x: x[1],
+    #         reverse=True,
+    #     )[0][0]
+    #
+    #     description = "<SEP>".join(
+    #         sorted(set([dp["description"] for dp in node_data] + descriptions))
+    #     )
+    #     description = await _handle_kg_summary(
+    #         entity_name, description, llm_client, tokenizer_instance
+    #     )
+    #
+    #     source_id = "<SEP>".join(
+    #         set([dp["source_id"] for dp in node_data] + source_ids)
+    #     )
+    #
+    #     node_data = {
+    #         "entity_type": entity_type,
+    #         "description": description,
+    #         "source_id": source_id,
+    #     }
+    #     await kg_instance.upsert_node(entity_name, node_data=node_data)
+    #     node_data["entity_name"] = entity_name
+    #     return node_data
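The commented-out process_single_node block above previews the merge policy that merge_nodes is meant to implement. Its core rule, keeping the most frequent entity_type across new mentions and any type already stored on the node, can be isolated as a small standalone sketch (for clarity only; this helper is not part of the commit):

from collections import Counter

def majority_entity_type(node_data: list, existing_types: list) -> str:
    """Most frequent entity_type across new mentions and stored types;
    ties fall back to first-seen order, matching the sorted(...) variant."""
    counts = Counter([dp["entity_type"] for dp in node_data] + existing_types)
    return counts.most_common(1)[0][0]

# majority_entity_type([{"entity_type": "PERSON"}, {"entity_type": "ORG"}], ["PERSON"])
# -> "PERSON"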

graphgen/operators/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from graphgen.operators.build_kg.extract_kg import extract_kg
+from graphgen.operators.build_kg.build_kg import build_kg
 from graphgen.operators.generate.generate_cot import generate_cot
 from graphgen.operators.search.search_all import search_all
 

graphgen/operators/build_kg/build_kg.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+from collections import defaultdict
+from typing import List
+
+import gradio as gr
+
+from graphgen.bases.base_storage import BaseGraphStorage
+from graphgen.bases.datatypes import Chunk
+from graphgen.models import LightRAGKGBuilder, OpenAIClient, Tokenizer
+from graphgen.operators.build_kg.merge_kg import merge_edges, merge_nodes
+from graphgen.utils import run_concurrent
+
+
+async def build_kg(
+    llm_client: OpenAIClient,
+    kg_instance: BaseGraphStorage,
+    tokenizer_instance: Tokenizer,
+    chunks: List[Chunk],
+    progress_bar: gr.Progress = None,
+):
+    """
+    :param llm_client: synthesizer LLM model to extract entities and relationships
+    :param kg_instance: graph storage into which the extracted KG is merged
+    :param tokenizer_instance: tokenizer passed through to the merge step
+    :param chunks: text chunks to extract from
+    :param progress_bar: Gradio progress bar to show the progress of the extraction
+    :return: the updated kg_instance
+    """
+
+    kg_builder = LightRAGKGBuilder(llm_client=llm_client, max_loop=3)
+
+    results = await run_concurrent(
+        kg_builder.extract,
+        chunks,
+        desc="[2/4] Extracting entities and relationships from chunks",
+        unit="chunk",
+        progress_bar=progress_bar,
+    )
+
+    nodes = defaultdict(list)
+    edges = defaultdict(list)
+    for n, e in results:
+        for k, v in n.items():
+            nodes[k].extend(v)
+        for k, v in e.items():
+            edges[tuple(sorted(k))].extend(v)
+
+    await merge_nodes(nodes, kg_instance, llm_client, tokenizer_instance)
+    await merge_edges(edges, kg_instance, llm_client, tokenizer_instance)
+
+    return kg_instance
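Note the edge-key normalization above: tuple(sorted(k)) means (A, B) extracted from one chunk and (B, A) from another accumulate in the same bucket, so relations are merged as undirected pairs. For reference, a hypothetical call-site sketch mirroring the wiring in GraphGen.insert (the client, storage, and tokenizer instances are assumed to exist already; this is not verbatim project code):

# Inside an async context, with the three instances constructed elsewhere
# (as on the GraphGen class in graphgen/graphgen.py):
graph = await build_kg(
    llm_client=synthesizer_llm_client,   # an OpenAIClient
    kg_instance=graph_storage,           # a BaseGraphStorage implementation
    tokenizer_instance=tokenizer,        # a Tokenizer
    chunks=chunks,                       # List[Chunk], e.g. from chunk_documents
    progress_bar=None,                   # or a gr.Progress() in the Gradio app
)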
