Skip to content

Commit 776cc20

Browse files
fix(webui): refine gradio progress_bar output
1 parent 24d9872 commit 776cc20

File tree

4 files changed

+51
-11
lines changed

4 files changed

+51
-11
lines changed

graphgen/graphgen.py

Lines changed: 22 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from dataclasses import dataclass
88

99
from tqdm.asyncio import tqdm as tqdm_async
10+
import gradio as gr
1011

1112
from .models import Chunk, JsonKVStorage, OpenAIModel, NetworkXStorage, WikiSearch, Tokenizer, TraverseStrategy
1213
from .models.storage.base_storage import StorageNameSpace
@@ -39,6 +40,9 @@ class GraphGen:
3940
# traverse strategy
4041
traverse_strategy: TraverseStrategy = TraverseStrategy()
4142

43+
# webui
44+
progress_bar: gr.Progress = None
45+
4246
def __post_init__(self):
4347
self.full_docs_storage: JsonKVStorage = JsonKVStorage(
4448
self.working_dir, namespace="full_docs"
@@ -78,6 +82,9 @@ async def async_split_chunks(self, data: Union[List[list], List[dict]], data_typ
7882
logger.warning("All docs are already in the storage")
7983
return {}
8084
logger.info("[New Docs] inserting %d docs", len(new_docs))
85+
86+
cur_index = 1
87+
doc_number = len(new_docs)
8188
for doc_key, doc in tqdm_async(
8289
new_docs.items(), desc="Chunking documents", unit="doc"
8390
):
@@ -89,6 +96,13 @@ async def async_split_chunks(self, data: Union[List[list], List[dict]], data_typ
8996
self.chunk_overlap_size, self.chunk_size)
9097
}
9198
inserting_chunks.update(chunks)
99+
100+
if self.progress_bar is not None:
101+
self.progress_bar(
102+
cur_index / doc_number, f"Chunking {doc_key}"
103+
)
104+
cur_index += 1
105+
92106
_add_chunk_keys = await self.text_chunks_storage.filter_keys(list(inserting_chunks.keys()))
93107
inserting_chunks = {k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys}
94108
elif data_type == "chunked":
@@ -141,7 +155,8 @@ async def async_insert(self, data: Union[List[list], List[dict]], data_type: str
141155
llm_client=self.synthesizer_llm_client,
142156
kg_instance=self.graph_storage,
143157
tokenizer_instance=self.tokenizer_instance,
144-
chunks=[Chunk(id=k, content=v['content']) for k, v in inserting_chunks.items()]
158+
chunks=[Chunk(id=k, content=v['content']) for k, v in inserting_chunks.items()],
159+
progress_bar = self.progress_bar,
145160
)
146161
if not _add_entities_and_relations:
147162
logger.warning("No entities or relations extracted")
@@ -199,16 +214,19 @@ async def async_traverse(self):
199214
self.tokenizer_instance,
200215
self.graph_storage,
201216
self.traverse_strategy,
202-
self.text_chunks_storage)
217+
self.text_chunks_storage,
218+
self.progress_bar)
203219
elif self.traverse_strategy.qa_form == "multi_hop":
204220
results = await traverse_graph_for_multi_hop(self.synthesizer_llm_client,
205221
self.tokenizer_instance,
206222
self.graph_storage,
207223
self.traverse_strategy,
208-
self.text_chunks_storage)
224+
self.text_chunks_storage,
225+
self.progress_bar)
209226
else:
210227
results = await traverse_graph_by_edge(self.synthesizer_llm_client, self.tokenizer_instance,
211-
self.graph_storage, self.traverse_strategy, self.text_chunks_storage)
228+
self.graph_storage, self.traverse_strategy, self.text_chunks_storage,
229+
self.progress_bar)
212230
await self.qa_storage.upsert(results)
213231
await self.qa_storage.index_done_callback()
214232

graphgen/operators/extract_kg.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
11
import re
22
import asyncio
3-
43
from typing import List
54
from collections import defaultdict
5+
6+
import gradio as gr
67
from tqdm.asyncio import tqdm as tqdm_async
78
from graphgen.models import Chunk, OpenAIModel, Tokenizer
89
from graphgen.models.storage.base_storage import BaseGraphStorage
@@ -13,18 +14,21 @@
1314
from graphgen.operators.merge_kg import merge_nodes, merge_edges
1415

1516

17+
# pylint: disable=too-many-statements
1618
async def extract_kg(
1719
llm_client: OpenAIModel,
1820
kg_instance: BaseGraphStorage,
1921
tokenizer_instance: Tokenizer,
2022
chunks: List[Chunk],
23+
progress_bar: gr.Progress = None,
2124
max_concurrent: int = 1000
2225
):
2326
"""
2427
:param llm_client: Synthesizer LLM model to extract entities and relationships
2528
:param kg_instance
2629
:param tokenizer_instance
2730
:param chunks
31+
:param progress_bar: Gradio progress bar to show the progress of the extraction
2832
:param max_concurrent
2933
:return:
3034
"""
@@ -98,6 +102,7 @@ async def _process_single_content(chunk: Chunk, max_loop: int = 3):
98102
return dict(nodes), dict(edges)
99103

100104
results = []
105+
chunk_number = len(chunks)
101106
for result in tqdm_async(
102107
asyncio.as_completed([_process_single_content(c) for c in chunks]),
103108
total=len(chunks),
@@ -106,6 +111,8 @@ async def _process_single_content(chunk: Chunk, max_loop: int = 3):
106111
):
107112
try:
108113
results.append(await result)
114+
if progress_bar is not None:
115+
progress_bar(len(results) / chunk_number, desc="Extracting entities and relationships from chunks")
109116
except Exception as e: # pylint: disable=broad-except
110117
logger.error("Error occurred while extracting entities and relationships from chunks: %s", e)
111118

graphgen/operators/traverse_graph.py

Lines changed: 19 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import gradio as gr
23

34
from tqdm.asyncio import tqdm as tqdm_async
45

@@ -167,6 +168,7 @@ async def traverse_graph_by_edge(
167168
graph_storage: NetworkXStorage,
168169
traverse_strategy: TraverseStrategy,
169170
text_chunks_storage: JsonKVStorage,
171+
progress_bar: gr.Progress = None,
170172
max_concurrent: int = 1000
171173
) -> dict:
172174
"""
@@ -177,6 +179,7 @@ async def traverse_graph_by_edge(
177179
:param graph_storage
178180
:param traverse_strategy
179181
:param text_chunks_storage
182+
:param progress_bar
180183
:param max_concurrent
181184
:return: question and answer
182185
"""
@@ -289,11 +292,13 @@ async def _process_single_batch(
289292

290293
for result in tqdm_async(asyncio.as_completed(
291294
[_process_single_batch(batch) for batch in processing_batches]
292-
), total=len(processing_batches), desc="Processing batches"):
295+
), total=len(processing_batches), desc="Generating QAs"):
293296
try:
294297
results.update(await result)
298+
if progress_bar is not None:
299+
progress_bar(len(results) / len(processing_batches), desc="Generating QAs")
295300
except Exception as e: # pylint: disable=broad-except
296-
logger.error("Error occurred while processing batches: %s", e)
301+
logger.error("Error occurred while generating QA: %s", e)
297302

298303
return results
299304

@@ -304,6 +309,7 @@ async def traverse_graph_atomically(
304309
graph_storage: NetworkXStorage,
305310
traverse_strategy: TraverseStrategy,
306311
text_chunks_storage: JsonKVStorage,
312+
progress_bar: gr.Progress = None,
307313
max_concurrent: int = 1000
308314
) -> dict:
309315
"""
@@ -314,6 +320,7 @@ async def traverse_graph_atomically(
314320
:param graph_storage
315321
:param traverse_strategy
316322
:param text_chunks_storage
323+
:param progress_bar
317324
:param max_concurrent
318325
:return: question and answer
319326
"""
@@ -391,12 +398,14 @@ async def _generate_question(
391398
for result in tqdm_async(
392399
asyncio.as_completed([_generate_question(task) for task in tasks]),
393400
total=len(tasks),
394-
desc="Generating questions"
401+
desc="Generating QAs"
395402
):
396403
try:
397404
results.update(await result)
405+
if progress_bar is not None:
406+
progress_bar(len(results) / len(tasks), desc="Generating QAs")
398407
except Exception as e: # pylint: disable=broad-except
399-
logger.error("Error occurred while generating questions: %s", e)
408+
logger.error("Error occurred while generating QA: %s", e)
400409
return results
401410

402411
async def traverse_graph_for_multi_hop(
@@ -405,6 +414,7 @@ async def traverse_graph_for_multi_hop(
405414
graph_storage: NetworkXStorage,
406415
traverse_strategy: TraverseStrategy,
407416
text_chunks_storage: JsonKVStorage,
417+
progress_bar: gr.Progress = None,
408418
max_concurrent: int = 1000
409419
) -> dict:
410420
"""
@@ -415,6 +425,7 @@ async def traverse_graph_for_multi_hop(
415425
:param graph_storage
416426
:param traverse_strategy
417427
:param text_chunks_storage
428+
:param progress_bar
418429
:param max_concurrent
419430
:return: question and answer
420431
"""
@@ -499,10 +510,12 @@ async def _process_single_batch(
499510
for result in tqdm_async(
500511
asyncio.as_completed([_process_single_batch(batch) for batch in processing_batches]),
501512
total=len(processing_batches),
502-
desc="Processing batches"
513+
desc="Generating QAs"
503514
):
504515
try:
505516
results.update(await result)
517+
if progress_bar is not None:
518+
progress_bar(len(results) / len(processing_batches), desc="Generating QAs")
506519
except Exception as e: # pylint: disable=broad-except
507-
logger.error("Error occurred while processing batches: %s", e)
520+
logger.error("Error occurred while generating QA: %s", e)
508521
return results

webui/app.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,8 @@ def run_graphgen(*arguments: list, progress=gr.Progress()):
100100
graph_gen.clear()
101101
progress(0.2, "Model Initialized")
102102

103+
graph_gen.progress_bar = progress
104+
103105
try:
104106
# Load input data
105107
file = config['input_file']

0 commit comments

Comments (0)