Skip to content

Commit 8fd34b2

Browse files
refactor: abstract run_concurrent & delete semaphore
1 parent 051dc77 commit 8fd34b2

File tree

4 files changed

+110
-96
lines changed

4 files changed

+110
-96
lines changed

graphgen/graphgen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ async def insert(self):
136136

137137
inserting_chunks = await chunk_documents(
138138
new_docs,
139-
self.chunk_size,
140-
self.chunk_overlap,
139+
self.config["split"]["chunk_size"],
140+
self.config["split"]["chunk_overlap"],
141141
self.tokenizer_instance,
142142
self.progress_bar,
143143
)

graphgen/operators/build_kg/extract_kg.py

Lines changed: 69 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
import asyncio
21
import re
32
from collections import defaultdict
43
from typing import List
54

65
import gradio as gr
7-
from tqdm.asyncio import tqdm as tqdm_async
86

97
from graphgen.bases.base_storage import BaseGraphStorage
108
from graphgen.bases.datatypes import Chunk
@@ -17,6 +15,7 @@
1715
handle_single_relationship_extraction,
1816
logger,
1917
pack_history_conversations,
18+
run_concurrent,
2019
split_string_by_multi_markers,
2120
)
2221

@@ -28,115 +27,91 @@ async def extract_kg(
2827
tokenizer_instance: Tokenizer,
2928
chunks: List[Chunk],
3029
progress_bar: gr.Progress = None,
31-
max_concurrent: int = 1000,
3230
):
3331
"""
3432
:param llm_client: Synthesizer LLM model to extract entities and relationships
3533
:param kg_instance
3634
:param tokenizer_instance
3735
:param chunks
3836
:param progress_bar: Gradio progress bar to show the progress of the extraction
39-
:param max_concurrent
4037
:return:
4138
"""
4239

43-
semaphore = asyncio.Semaphore(max_concurrent)
44-
4540
async def _process_single_content(chunk: Chunk, max_loop: int = 3):
46-
async with semaphore:
47-
chunk_id = chunk.id
48-
content = chunk.content
49-
if detect_if_chinese(content):
50-
language = "Chinese"
51-
else:
52-
language = "English"
53-
KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
54-
55-
hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
56-
**KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
41+
chunk_id = chunk.id
42+
content = chunk.content
43+
if detect_if_chinese(content):
44+
language = "Chinese"
45+
else:
46+
language = "English"
47+
KG_EXTRACTION_PROMPT["FORMAT"]["language"] = language
48+
49+
hint_prompt = KG_EXTRACTION_PROMPT[language]["TEMPLATE"].format(
50+
**KG_EXTRACTION_PROMPT["FORMAT"], input_text=content
51+
)
52+
53+
final_result = await llm_client.generate_answer(hint_prompt)
54+
logger.info("First result: %s", final_result)
55+
56+
history = pack_history_conversations(hint_prompt, final_result)
57+
for loop_index in range(max_loop):
58+
if_loop_result = await llm_client.generate_answer(
59+
text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
60+
)
61+
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
62+
if if_loop_result != "yes":
63+
break
64+
65+
glean_result = await llm_client.generate_answer(
66+
text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
5767
)
68+
logger.info("Loop %s glean: %s", loop_index, glean_result)
5869

59-
final_result = await llm_client.generate_answer(hint_prompt)
60-
logger.info("First result: %s", final_result)
61-
62-
history = pack_history_conversations(hint_prompt, final_result)
63-
for loop_index in range(max_loop):
64-
if_loop_result = await llm_client.generate_answer(
65-
text=KG_EXTRACTION_PROMPT[language]["IF_LOOP"], history=history
66-
)
67-
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
68-
if if_loop_result != "yes":
69-
break
70-
71-
glean_result = await llm_client.generate_answer(
72-
text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
73-
)
74-
logger.info("Loop %s glean: %s", loop_index, glean_result)
75-
76-
history += pack_history_conversations(
77-
KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
78-
)
79-
final_result += glean_result
80-
if loop_index == max_loop - 1:
81-
break
82-
83-
records = split_string_by_multi_markers(
84-
final_result,
85-
[
86-
KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
87-
KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
88-
],
70+
history += pack_history_conversations(
71+
KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result
72+
)
73+
final_result += glean_result
74+
if loop_index == max_loop - 1:
75+
break
76+
77+
records = split_string_by_multi_markers(
78+
final_result,
79+
[
80+
KG_EXTRACTION_PROMPT["FORMAT"]["record_delimiter"],
81+
KG_EXTRACTION_PROMPT["FORMAT"]["completion_delimiter"],
82+
],
83+
)
84+
85+
nodes = defaultdict(list)
86+
edges = defaultdict(list)
87+
88+
for record in records:
89+
record = re.search(r"\((.*)\)", record)
90+
if record is None:
91+
continue
92+
record = record.group(1) # 提取括号内的内容
93+
record_attributes = split_string_by_multi_markers(
94+
record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
8995
)
9096

91-
nodes = defaultdict(list)
92-
edges = defaultdict(list)
93-
94-
for record in records:
95-
record = re.search(r"\((.*)\)", record)
96-
if record is None:
97-
continue
98-
record = record.group(1) # 提取括号内的内容
99-
record_attributes = split_string_by_multi_markers(
100-
record, [KG_EXTRACTION_PROMPT["FORMAT"]["tuple_delimiter"]]
101-
)
102-
103-
entity = await handle_single_entity_extraction(
104-
record_attributes, chunk_id
105-
)
106-
if entity is not None:
107-
nodes[entity["entity_name"]].append(entity)
108-
continue
109-
relation = await handle_single_relationship_extraction(
110-
record_attributes, chunk_id
111-
)
112-
if relation is not None:
113-
edges[(relation["src_id"], relation["tgt_id"])].append(relation)
114-
return dict(nodes), dict(edges)
115-
116-
results = []
117-
chunk_number = len(chunks)
118-
async for result in tqdm_async(
119-
asyncio.as_completed([_process_single_content(c) for c in chunks]),
120-
total=len(chunks),
97+
entity = await handle_single_entity_extraction(record_attributes, chunk_id)
98+
if entity is not None:
99+
nodes[entity["entity_name"]].append(entity)
100+
continue
101+
relation = await handle_single_relationship_extraction(
102+
record_attributes, chunk_id
103+
)
104+
if relation is not None:
105+
edges[(relation["src_id"], relation["tgt_id"])].append(relation)
106+
return dict(nodes), dict(edges)
107+
108+
results = await run_concurrent(
109+
_process_single_content,
110+
chunks,
121111
desc="[2/4]Extracting entities and relationships from chunks",
122112
unit="chunk",
123-
):
124-
try:
125-
if progress_bar is not None:
126-
progress_bar(
127-
len(results) / chunk_number,
128-
desc="[3/4]Extracting entities and relationships from chunks",
129-
)
130-
results.append(await result)
131-
if progress_bar is not None and len(results) == chunk_number:
132-
progress_bar(
133-
1, desc="[3/4]Extracting entities and relationships from chunks"
134-
)
135-
except Exception as e: # pylint: disable=broad-except
136-
logger.error(
137-
"Error occurred while extracting entities and relationships from chunks: %s",
138-
e,
139-
)
113+
progress_bar=progress_bar,
114+
)
140115

141116
nodes = defaultdict(list)
142117
edges = defaultdict(list)

graphgen/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@
1313
from .help_nltk import NLTKHelper
1414
from .log import logger, parse_log, set_logger
1515
from .loop import create_event_loop
16+
from .run_concurrent import run_concurrent
1617
from .wrap import async_to_sync_method

graphgen/utils/run_concurrent.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import asyncio
2+
from typing import Awaitable, Callable, List, Optional, TypeVar
3+
4+
import gradio as gr
5+
from tqdm.asyncio import tqdm as tqdm_async
6+
7+
from graphgen.utils.log import logger
8+
9+
T = TypeVar("T")
10+
R = TypeVar("R")
11+
12+
13+
async def run_concurrent(
    coro_fn: Callable[[T], Awaitable[R]],
    items: List[T],
    *,
    desc: str = "processing",
    unit: str = "item",
    progress_bar: Optional[gr.Progress] = None,
) -> List[R]:
    """Run ``coro_fn`` over ``items`` concurrently with a tqdm progress bar.

    Failed tasks are logged and skipped, so the returned list may be shorter
    than ``items``. Successful results are returned in the order ``gather``
    yields them.

    :param coro_fn: async callable applied to each item
    :param items: inputs to process
    :param desc: description shown on the tqdm / Gradio progress bars
    :param unit: tqdm unit label
    :param progress_bar: optional Gradio progress bar, updated once per item
    :return: results of the successful tasks only
    """

    async def _guarded(item):
        # tqdm's gather has no `return_exceptions` switch: a raising task
        # would propagate and abort the whole batch. Capture the exception
        # and return it so per-item failures are tolerated (log-and-skip).
        try:
            return await coro_fn(item)
        except Exception as exc:  # pylint: disable=broad-except
            return exc

    tasks = [asyncio.create_task(_guarded(it)) for it in items]

    results = await tqdm_async.gather(*tasks, desc=desc, unit=unit)

    ok_results: List[R] = []
    total = len(items)
    for idx, res in enumerate(results):
        if isinstance(res, Exception):
            # Not inside an `except` block, so logger.exception would log a
            # bogus "NoneType: None" traceback; log the error message instead.
            logger.error("Task failed: %s", res)
        else:
            ok_results.append(res)
        if progress_bar:
            progress_bar((idx + 1) / total, desc=desc)

    if progress_bar:
        # Ensure the bar reads complete even when `items` is empty.
        progress_bar(1.0, desc=desc)
    return ok_results

0 commit comments

Comments
 (0)