Skip to content

Commit a2f2b80

Browse files
Merge pull request #64 from open-sciencelab/update-gradio
Update gradio
2 parents 3e3b0ff + 3486e0d commit a2f2b80

File tree

12 files changed

+456
-234
lines changed

12 files changed

+456
-234
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ For any questions, please check [FAQ](https://github.com/open-sciencelab/GraphGe
105105
```bash
106106
python -m webui.app
107107
```
108+
109+
For hot reload during development, run:
110+
```bash
111+
PYTHONPATH=. gradio webui/app.py
112+
```
113+
108114

109115
![ui](https://github.com/user-attachments/assets/3024e9bc-5d45-45f8-a4e6-b57bd2350d84)
110116

README_ZH.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,13 @@ GraphGen 首先根据源文本构建细粒度的知识图谱,然后利用期
104104
```bash
105105
python -m webui.app
106106
```
107+
108+
如果在开发过程中需要热重载,请运行
109+
110+
```bash
111+
PYTHONPATH=. gradio webui/app.py
112+
```
113+
107114

108115
![ui](https://github.com/user-attachments/assets/3024e9bc-5d45-45f8-a4e6-b57bd2350d84)
109116

graphgen/configs/aggregated_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ partition: # graph partition configuration
1616
max_units_per_community: 20 # max nodes and edges per community
1717
min_units_per_community: 5 # min nodes and edges per community
1818
max_tokens_per_community: 10240 # max tokens per community
19-
unit_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
19+
unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
2020
generate:
2121
mode: aggregated # atomic, aggregated, multi_hop, cot
2222
data_format: ChatML # Alpaca, Sharegpt, ChatML

graphgen/configs/multi_hop_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ partition: # graph partition configuration
1616
max_units_per_community: 3 # max nodes and edges per community, for multi-hop, we recommend setting it to 3
1717
min_units_per_community: 3 # min nodes and edges per community, for multi-hop, we recommend setting it to 3
1818
max_tokens_per_community: 10240 # max tokens per community
19-
unit_sampling: random # edge sampling strategy, support: random, max_loss, min_loss
19+
unit_sampling: random # unit sampling strategy, support: random, max_loss, min_loss
2020
generate:
2121
mode: multi_hop # strategy for generating multi-hop QA pairs
2222
data_format: ChatML # Alpaca, Sharegpt, ChatML

graphgen/graphgen.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,10 @@ async def generate(self, partition_config: Dict, generate_config: Dict):
237237

238238
# Step 2: generate QA pairs
239239
results = await generate_qas(
240-
self.synthesizer_llm_client, batches, generate_config
240+
self.synthesizer_llm_client,
241+
batches,
242+
generate_config,
243+
progress_bar=self.progress_bar,
241244
)
242245

243246
if not results:

graphgen/models/kg_builder/light_rag_kg_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ async def extract(
4242

4343
# step 2: initial glean
4444
final_result = await self.llm_client.generate_answer(hint_prompt)
45-
logger.debug("First extraction result: %s", final_result)
45+
logger.info("First extraction result: %s", final_result)
4646

4747
# step3: iterative refinement
4848
history = pack_history_conversations(hint_prompt, final_result)
@@ -57,7 +57,7 @@ async def extract(
5757
glean_result = await self.llm_client.generate_answer(
5858
text=KG_EXTRACTION_PROMPT[language]["CONTINUE"], history=history
5959
)
60-
logger.debug("Loop %s glean: %s", loop_idx + 1, glean_result)
60+
logger.info("Loop %s glean: %s", loop_idx + 1, glean_result)
6161

6262
history += pack_history_conversations(
6363
KG_EXTRACTION_PROMPT[language]["CONTINUE"], glean_result

graphgen/models/partitioner/ece_partitioner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
class ECEPartitioner(BFSPartitioner):
1818
"""
1919
ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE).
20-
We calculate ECE for edges in KG (represented as 'comprehension loss')
21-
and group edges with similar ECE values into the same community.
20+
We calculate ECE for units in KG (represented as 'comprehension loss')
21+
and group units with similar ECE values into the same community.
2222
1. Select a sampling strategy.
2323
2. Choose a unit based on the sampling strategy.
2424
2. Expand the community using BFS.

graphgen/operators/generate/generate_qas.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ async def generate_qas(
1818
]
1919
],
2020
generation_config: dict,
21+
progress_bar=None,
2122
) -> list[dict[str, Any]]:
2223
"""
2324
Generate question-answer pairs based on nodes and edges.
2425
:param llm_client: LLM client
2526
:param batches
2627
:param generation_config
28+
:param progress_bar
2729
:return: QA pairs
2830
"""
2931
mode = generation_config["mode"]
@@ -45,6 +47,7 @@ async def generate_qas(
4547
batches,
4648
desc="[4/4]Generating QAs",
4749
unit="batch",
50+
progress_bar=progress_bar,
4851
)
4952

5053
# format

graphgen/utils/run_concurrent.py

Lines changed: 104 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,77 @@
1010
R = TypeVar("R")
1111

1212

13+
# async def run_concurrent(
14+
# coro_fn: Callable[[T], Awaitable[R]],
15+
# items: List[T],
16+
# *,
17+
# desc: str = "processing",
18+
# unit: str = "item",
19+
# progress_bar: Optional[gr.Progress] = None,
20+
# ) -> List[R]:
21+
# tasks = [asyncio.create_task(coro_fn(it)) for it in items]
22+
#
23+
# results = []
24+
# async for future in tqdm_async(
25+
# tasks, desc=desc, unit=unit
26+
# ):
27+
# try:
28+
# result = await future
29+
# results.append(result)
30+
# except Exception as e: # pylint: disable=broad-except
31+
# logger.exception("Task failed: %s", e)
32+
#
33+
# if progress_bar is not None:
34+
# progress_bar((len(results)) / len(items), desc=desc)
35+
#
36+
# if progress_bar is not None:
37+
# progress_bar(1.0, desc=desc)
38+
# return results
39+
40+
# results = await tqdm_async.gather(*tasks, desc=desc, unit=unit)
41+
#
42+
# ok_results = []
43+
# for idx, res in enumerate(results):
44+
# if isinstance(res, Exception):
45+
# logger.exception("Task failed: %s", res)
46+
# if progress_bar:
47+
# progress_bar((idx + 1) / len(items), desc=desc)
48+
# continue
49+
# ok_results.append(res)
50+
# if progress_bar:
51+
# progress_bar((idx + 1) / len(items), desc=desc)
52+
#
53+
# if progress_bar:
54+
# progress_bar(1.0, desc=desc)
55+
# return ok_results
56+
57+
# async def run_concurrent(
58+
# coro_fn: Callable[[T], Awaitable[R]],
59+
# items: List[T],
60+
# *,
61+
# desc: str = "processing",
62+
# unit: str = "item",
63+
# progress_bar: Optional[gr.Progress] = None,
64+
# ) -> List[R]:
65+
# tasks = [asyncio.create_task(coro_fn(it)) for it in items]
66+
#
67+
# results = []
68+
# # 使用同步方式更新进度条,避免异步冲突
69+
# for i, task in enumerate(asyncio.as_completed(tasks)):
70+
# try:
71+
# result = await task
72+
# results.append(result)
73+
# # 同步更新进度条
74+
# if progress_bar is not None:
75+
# # 在同步上下文中更新进度
76+
# progress_bar((i + 1) / len(items), desc=desc)
77+
# except Exception as e:
78+
# logger.exception("Task failed: %s", e)
79+
# results.append(e)
80+
#
81+
# return results
82+
83+
1384
async def run_concurrent(
1485
coro_fn: Callable[[T], Awaitable[R]],
1586
items: List[T],
@@ -20,19 +91,36 @@ async def run_concurrent(
2091
) -> List[R]:
2192
tasks = [asyncio.create_task(coro_fn(it)) for it in items]
2293

23-
results = await tqdm_async.gather(*tasks, desc=desc, unit=unit)
24-
25-
ok_results = []
26-
for idx, res in enumerate(results):
27-
if isinstance(res, Exception):
28-
logger.exception("Task failed: %s", res)
29-
if progress_bar:
30-
progress_bar((idx + 1) / len(items), desc=desc)
31-
continue
32-
ok_results.append(res)
33-
if progress_bar:
34-
progress_bar((idx + 1) / len(items), desc=desc)
35-
36-
if progress_bar:
37-
progress_bar(1.0, desc=desc)
38-
return ok_results
94+
completed_count = 0
95+
results = []
96+
97+
pbar = tqdm_async(total=len(items), desc=desc, unit=unit)
98+
99+
if progress_bar is not None:
100+
progress_bar(0.0, desc=f"{desc} (0/{len(items)})")
101+
102+
for future in asyncio.as_completed(tasks):
103+
try:
104+
result = await future
105+
results.append(result)
106+
except Exception as e: # pylint: disable=broad-except
107+
logger.exception("Task failed: %s", e)
108+
# even if failed, record it to keep results consistent with tasks
109+
results.append(e)
110+
111+
completed_count += 1
112+
pbar.update(1)
113+
114+
if progress_bar is not None:
115+
progress = completed_count / len(items)
116+
progress_bar(progress, desc=f"{desc} ({completed_count}/{len(items)})")
117+
118+
pbar.close()
119+
120+
if progress_bar is not None:
121+
progress_bar(1.0, desc=f"{desc} (completed)")
122+
123+
# filter out exceptions
124+
results = [res for res in results if not isinstance(res, Exception)]
125+
126+
return results

0 commit comments

Comments
 (0)