1- import asyncio
21import re
32from collections import defaultdict
43from typing import List
54
65import gradio as gr
7- from tqdm .asyncio import tqdm as tqdm_async
86
97from graphgen .bases .base_storage import BaseGraphStorage
108from graphgen .bases .datatypes import Chunk
1715 handle_single_relationship_extraction ,
1816 logger ,
1917 pack_history_conversations ,
18+ run_concurrent ,
2019 split_string_by_multi_markers ,
2120)
2221
@@ -28,115 +27,91 @@ async def extract_kg(
2827 tokenizer_instance : Tokenizer ,
2928 chunks : List [Chunk ],
3029 progress_bar : gr .Progress = None ,
31- max_concurrent : int = 1000 ,
3230):
3331 """
3432 :param llm_client: Synthesizer LLM model to extract entities and relationships
3533 :param kg_instance
3634 :param tokenizer_instance
3735 :param chunks
3836 :param progress_bar: Gradio progress bar to show the progress of the extraction
39- :param max_concurrent
4037 :return:
4138 """
4239
43- semaphore = asyncio .Semaphore (max_concurrent )
44-
4540 async def _process_single_content (chunk : Chunk , max_loop : int = 3 ):
46- async with semaphore :
47- chunk_id = chunk .id
48- content = chunk .content
49- if detect_if_chinese (content ):
50- language = "Chinese"
51- else :
52- language = "English"
53- KG_EXTRACTION_PROMPT ["FORMAT" ]["language" ] = language
54-
55- hint_prompt = KG_EXTRACTION_PROMPT [language ]["TEMPLATE" ].format (
56- ** KG_EXTRACTION_PROMPT ["FORMAT" ], input_text = content
41+ chunk_id = chunk .id
42+ content = chunk .content
43+ if detect_if_chinese (content ):
44+ language = "Chinese"
45+ else :
46+ language = "English"
47+ KG_EXTRACTION_PROMPT ["FORMAT" ]["language" ] = language
48+
49+ hint_prompt = KG_EXTRACTION_PROMPT [language ]["TEMPLATE" ].format (
50+ ** KG_EXTRACTION_PROMPT ["FORMAT" ], input_text = content
51+ )
52+
53+ final_result = await llm_client .generate_answer (hint_prompt )
54+ logger .info ("First result: %s" , final_result )
55+
56+ history = pack_history_conversations (hint_prompt , final_result )
57+ for loop_index in range (max_loop ):
58+ if_loop_result = await llm_client .generate_answer (
59+ text = KG_EXTRACTION_PROMPT [language ]["IF_LOOP" ], history = history
60+ )
61+ if_loop_result = if_loop_result .strip ().strip ('"' ).strip ("'" ).lower ()
62+ if if_loop_result != "yes" :
63+ break
64+
65+ glean_result = await llm_client .generate_answer (
66+ text = KG_EXTRACTION_PROMPT [language ]["CONTINUE" ], history = history
5767 )
68+ logger .info ("Loop %s glean: %s" , loop_index , glean_result )
5869
59- final_result = await llm_client .generate_answer (hint_prompt )
60- logger .info ("First result: %s" , final_result )
61-
62- history = pack_history_conversations (hint_prompt , final_result )
63- for loop_index in range (max_loop ):
64- if_loop_result = await llm_client .generate_answer (
65- text = KG_EXTRACTION_PROMPT [language ]["IF_LOOP" ], history = history
66- )
67- if_loop_result = if_loop_result .strip ().strip ('"' ).strip ("'" ).lower ()
68- if if_loop_result != "yes" :
69- break
70-
71- glean_result = await llm_client .generate_answer (
72- text = KG_EXTRACTION_PROMPT [language ]["CONTINUE" ], history = history
73- )
74- logger .info ("Loop %s glean: %s" , loop_index , glean_result )
75-
76- history += pack_history_conversations (
77- KG_EXTRACTION_PROMPT [language ]["CONTINUE" ], glean_result
78- )
79- final_result += glean_result
80- if loop_index == max_loop - 1 :
81- break
82-
83- records = split_string_by_multi_markers (
84- final_result ,
85- [
86- KG_EXTRACTION_PROMPT ["FORMAT" ]["record_delimiter" ],
87- KG_EXTRACTION_PROMPT ["FORMAT" ]["completion_delimiter" ],
88- ],
70+ history += pack_history_conversations (
71+ KG_EXTRACTION_PROMPT [language ]["CONTINUE" ], glean_result
72+ )
73+ final_result += glean_result
74+ if loop_index == max_loop - 1 :
75+ break
76+
77+ records = split_string_by_multi_markers (
78+ final_result ,
79+ [
80+ KG_EXTRACTION_PROMPT ["FORMAT" ]["record_delimiter" ],
81+ KG_EXTRACTION_PROMPT ["FORMAT" ]["completion_delimiter" ],
82+ ],
83+ )
84+
85+ nodes = defaultdict (list )
86+ edges = defaultdict (list )
87+
88+ for record in records :
89+ record = re .search (r"\((.*)\)" , record )
90+ if record is None :
91+ continue
92+ record = record .group (1 ) # 提取括号内的内容
93+ record_attributes = split_string_by_multi_markers (
94+ record , [KG_EXTRACTION_PROMPT ["FORMAT" ]["tuple_delimiter" ]]
8995 )
9096
91- nodes = defaultdict (list )
92- edges = defaultdict (list )
93-
94- for record in records :
95- record = re .search (r"\((.*)\)" , record )
96- if record is None :
97- continue
98- record = record .group (1 ) # 提取括号内的内容
99- record_attributes = split_string_by_multi_markers (
100- record , [KG_EXTRACTION_PROMPT ["FORMAT" ]["tuple_delimiter" ]]
101- )
102-
103- entity = await handle_single_entity_extraction (
104- record_attributes , chunk_id
105- )
106- if entity is not None :
107- nodes [entity ["entity_name" ]].append (entity )
108- continue
109- relation = await handle_single_relationship_extraction (
110- record_attributes , chunk_id
111- )
112- if relation is not None :
113- edges [(relation ["src_id" ], relation ["tgt_id" ])].append (relation )
114- return dict (nodes ), dict (edges )
115-
116- results = []
117- chunk_number = len (chunks )
118- async for result in tqdm_async (
119- asyncio .as_completed ([_process_single_content (c ) for c in chunks ]),
120- total = len (chunks ),
97+ entity = await handle_single_entity_extraction (record_attributes , chunk_id )
98+ if entity is not None :
99+ nodes [entity ["entity_name" ]].append (entity )
100+ continue
101+ relation = await handle_single_relationship_extraction (
102+ record_attributes , chunk_id
103+ )
104+ if relation is not None :
105+ edges [(relation ["src_id" ], relation ["tgt_id" ])].append (relation )
106+ return dict (nodes ), dict (edges )
107+
108+ results = await run_concurrent (
109+ _process_single_content ,
110+ chunks ,
121111 desc = "[2/4]Extracting entities and relationships from chunks" ,
122112 unit = "chunk" ,
123- ):
124- try :
125- if progress_bar is not None :
126- progress_bar (
127- len (results ) / chunk_number ,
128- desc = "[3/4]Extracting entities and relationships from chunks" ,
129- )
130- results .append (await result )
131- if progress_bar is not None and len (results ) == chunk_number :
132- progress_bar (
133- 1 , desc = "[3/4]Extracting entities and relationships from chunks"
134- )
135- except Exception as e : # pylint: disable=broad-except
136- logger .error (
137- "Error occurred while extracting entities and relationships from chunks: %s" ,
138- e ,
139- )
113+ progress_bar = progress_bar ,
114+ )
140115
141116 nodes = defaultdict (list )
142117 edges = defaultdict (list )