Skip to content

Commit cb1a63b

Browse files
fix: fix progress_bar in run_concurrent
1 parent f1f0b2d commit cb1a63b

File tree

4 files changed

+113
-19
lines changed

4 files changed

+113
-19
lines changed

graphgen/graphgen.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,10 @@ async def generate(self, partition_config: Dict, generate_config: Dict):
237237

238238
# Step 2: generate QA pairs
239239
results = await generate_qas(
240-
self.synthesizer_llm_client, batches, generate_config
240+
self.synthesizer_llm_client,
241+
batches,
242+
generate_config,
243+
progress_bar=self.progress_bar,
241244
)
242245

243246
if not results:

graphgen/models/kg_builder/light_rag_kg_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ async def extract(
4242

4343
# step 2: initial glean
4444
final_result = await self.llm_client.generate_answer(hint_prompt)
45-
logger.debug("First extraction result: %s", final_result)
45+
logger.info("First extraction result: %s", final_result)
4646

4747
# step3: iterative refinement
4848
history = pack_history_conversations(hint_prompt, final_result)
@@ -57,7 +57,7 @@ async def extract(
5757
glean_result = await self.llm_client.generate_answer(
5858
text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
5959
)
60-
logger.debug("Loop %s glean: %s", loop_idx + 1, glean_result)
60+
logger.info("Loop %s glean: %s", loop_idx + 1, glean_result)
6161

6262
history += pack_history_conversations(
6363
KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result

graphgen/operators/generate/generate_qas.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ async def generate_qas(
1818
]
1919
],
2020
generation_config: dict,
21+
progress_bar=None,
2122
) -> list[dict[str, Any]]:
2223
"""
2324
Generate question-answer pairs based on nodes and edges.
2425
:param llm_client: LLM client
2526
:param batches
2627
:param generation_config
28+
:param progress_bar
2729
:return: QA pairs
2830
"""
2931
mode = generation_config["mode"]
@@ -45,6 +47,7 @@ async def generate_qas(
4547
batches,
4648
desc="[4/4]Generating QAs",
4749
unit="batch",
50+
progress_bar=progress_bar,
4851
)
4952

5053
# format

graphgen/utils/run_concurrent.py

Lines changed: 104 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,77 @@
1010
R = TypeVar("R")
1111

1212

13+
# async def run_concurrent(
14+
# coro_fn: Callable[[T], Awaitable[R]],
15+
# items: List[T],
16+
# *,
17+
# desc: str = "processing",
18+
# unit: str = "item",
19+
# progress_bar: Optional[gr.Progress] = None,
20+
# ) -> List[R]:
21+
# tasks = [asyncio.create_task(coro_fn(it)) for it in items]
22+
#
23+
# results = []
24+
# async for future in tqdm_async(
25+
# tasks, desc=desc, unit=unit
26+
# ):
27+
# try:
28+
# result = await future
29+
# results.append(result)
30+
# except Exception as e: # pylint: disable=broad-except
31+
# logger.exception("Task failed: %s", e)
32+
#
33+
# if progress_bar is not None:
34+
# progress_bar((len(results)) / len(items), desc=desc)
35+
#
36+
# if progress_bar is not None:
37+
# progress_bar(1.0, desc=desc)
38+
# return results
39+
40+
# results = await tqdm_async.gather(*tasks, desc=desc, unit=unit)
41+
#
42+
# ok_results = []
43+
# for idx, res in enumerate(results):
44+
# if isinstance(res, Exception):
45+
# logger.exception("Task failed: %s", res)
46+
# if progress_bar:
47+
# progress_bar((idx + 1) / len(items), desc=desc)
48+
# continue
49+
# ok_results.append(res)
50+
# if progress_bar:
51+
# progress_bar((idx + 1) / len(items), desc=desc)
52+
#
53+
# if progress_bar:
54+
# progress_bar(1.0, desc=desc)
55+
# return ok_results
56+
57+
# async def run_concurrent(
58+
# coro_fn: Callable[[T], Awaitable[R]],
59+
# items: List[T],
60+
# *,
61+
# desc: str = "processing",
62+
# unit: str = "item",
63+
# progress_bar: Optional[gr.Progress] = None,
64+
# ) -> List[R]:
65+
# tasks = [asyncio.create_task(coro_fn(it)) for it in items]
66+
#
67+
# results = []
68+
# # 使用同步方式更新进度条,避免异步冲突
69+
# for i, task in enumerate(asyncio.as_completed(tasks)):
70+
# try:
71+
# result = await task
72+
# results.append(result)
73+
# # 同步更新进度条
74+
# if progress_bar is not None:
75+
# # 在同步上下文中更新进度
76+
# progress_bar((i + 1) / len(items), desc=desc)
77+
# except Exception as e:
78+
# logger.exception("Task failed: %s", e)
79+
# results.append(e)
80+
#
81+
# return results
82+
83+
1384
async def run_concurrent(
1485
coro_fn: Callable[[T], Awaitable[R]],
1586
items: List[T],
@@ -20,19 +91,36 @@ async def run_concurrent(
2091
) -> List[R]:
2192
tasks = [asyncio.create_task(coro_fn(it)) for it in items]
2293

23-
results = await tqdm_async.gather(*tasks, desc=desc, unit=unit)
24-
25-
ok_results = []
26-
for idx, res in enumerate(results):
27-
if isinstance(res, Exception):
28-
logger.exception("Task failed: %s", res)
29-
if progress_bar:
30-
progress_bar((idx + 1) / len(items), desc=desc)
31-
continue
32-
ok_results.append(res)
33-
if progress_bar:
34-
progress_bar((idx + 1) / len(items), desc=desc)
35-
36-
if progress_bar:
37-
progress_bar(1.0, desc=desc)
38-
return ok_results
94+
completed_count = 0
95+
results = []
96+
97+
pbar = tqdm_async(total=len(items), desc=desc, unit=unit)
98+
99+
if progress_bar is not None:
100+
progress_bar(0.0, desc=f"{desc} (0/{len(items)})")
101+
102+
for future in asyncio.as_completed(tasks):
103+
try:
104+
result = await future
105+
results.append(result)
106+
except Exception as e: # pylint: disable=broad-except
107+
logger.exception("Task failed: %s", e)
108+
# even if failed, record it to keep results consistent with tasks
109+
results.append(e)
110+
111+
completed_count += 1
112+
pbar.update(1)
113+
114+
if progress_bar is not None:
115+
progress = completed_count / len(items)
116+
progress_bar(progress, desc=f"{desc} ({completed_count}/{len(items)})")
117+
118+
pbar.close()
119+
120+
if progress_bar is not None:
121+
progress_bar(1.0, desc=f"{desc} (completed)")
122+
123+
# filter out exceptions
124+
results = [res for res in results if not isinstance(res, Exception)]
125+
126+
return results

0 commit comments

Comments
 (0)