Skip to content

Commit c9435d7

Browse files
refactor: refactor quiz&judge to ray actors
1 parent bc07222 commit c9435d7

File tree

6 files changed

+122
-82
lines changed

6 files changed

+122
-82
lines changed
File renamed without changes.

graphgen/operators/judge/judge_service.py

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,76 @@
11
import math
22

3-
import gradio as gr
4-
5-
from graphgen.bases import BaseLLMWrapper
6-
from graphgen.models import JsonKVStorage, NetworkXStorage
7-
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
8-
from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy
9-
10-
11-
import math
12-
from collections.abc import Iterable
13-
143
import pandas as pd
154

16-
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
5+
from graphgen.bases import BaseGraphStorage, BaseLLMWrapper
176
from graphgen.common import init_llm, init_storage
18-
from graphgen.models import NetworkXStorage, JsonKVStorage
197
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
208
from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy
219

2210

2311
class JudgeService:
2412
"""Service for judging graph edges and nodes using a trainee LLM."""
13+
2514
def __init__(self, working_dir: str = "cache"):
2615
self.llm_client: BaseLLMWrapper = init_llm("trainee")
16+
self.graph_storage: BaseGraphStorage = init_storage(
17+
backend="networkx",
18+
working_dir=working_dir,
19+
namespace="graph",
20+
)
2721

2822
def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
23+
items = batch.to_dict(orient="records")
24+
self.graph_storage.reload()
25+
self.judge(items)
2926
return pd.DataFrame([{"status": "judging_completed"}])
3027

31-
def judge(self) -> Iterable[pd.DataFrame]:
32-
"""
33-
Judge the statements in the graph storage
28+
async def _process_single_judge(self, item: dict) -> dict:
29+
description = item["description"]
30+
try:
31+
judgement = await self.llm_client.generate_topk_per_token(
32+
STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
33+
)
34+
top_candidates = judgement[0].top_candidates
35+
gt = item.get("ground_truth", "yes")
36+
loss = yes_no_loss_entropy([top_candidates], [gt])
37+
logger.debug("Description: %s Loss: %s", description, loss)
38+
item["loss"] = loss
39+
except Exception as e: # pylint: disable=broad-except
40+
logger.error("Error in judging description: %s", e)
41+
logger.info("Use default loss 0.1")
42+
item["loss"] = -math.log(0.1)
43+
return item
3444

35-
:param re_judge: re-judge the relations
36-
:return:
45+
def judge(self, items: list[dict]) -> None:
46+
"""
47+
Judge the description in the item and compute the loss.
3748
"""
38-
return
49+
results = run_concurrent(
50+
self._process_single_judge,
51+
items,
52+
desc="Judging descriptions",
53+
unit="description",
54+
)
3955

56+
# Update the graph storage with the computed losses
57+
for item in results:
58+
print(item)
59+
node_id = item.get("node_id")
60+
edge_source = item.get("edge_source")
61+
edge_target = item.get("edge_target")
62+
loss = item["loss"]
63+
if node_id is not None:
64+
node_data = self.graph_storage.get_node(node_id)
65+
if node_data is not None:
66+
node_data["loss"] = loss
67+
self.graph_storage.update_node(node_id, node_data)
68+
elif edge_source is not None and edge_target is not None:
69+
edge_data = self.graph_storage.get_edge(edge_source, edge_target)
70+
if edge_data is not None:
71+
edge_data["loss"] = loss
72+
self.graph_storage.update_edge(edge_source, edge_target, edge_data)
73+
self.graph_storage.index_done_callback()
4074

4175

4276
# async def judge_statement( # pylint: disable=too-many-statements

graphgen/operators/partition/partition_kg.py renamed to graphgen/operators/partition/partition_service.py

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@
1010
)
1111
from graphgen.utils import logger
1212

13-
from .pre_tokenize import pre_tokenize
1413

15-
16-
async def partition_kg(
14+
def partition_kg(
1715
kg_instance: BaseGraphStorage,
1816
chunk_storage: BaseKVStorage,
1917
tokenizer: Any = BaseTokenizer,
@@ -60,7 +58,7 @@ async def partition_kg(
6058
return batches
6159

6260

63-
async def attach_additional_data_to_node(
61+
def attach_additional_data_to_node(
6462
batches: list[
6563
tuple[
6664
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
@@ -112,3 +110,61 @@ async def _attach_by_type(
112110
# We'll use the first image chunk found for this node.
113111
node_data["images"] = image_chunks[0]
114112
logger.debug("Attached image data to node %s", node_id)
113+
114+
115+
import asyncio
116+
from typing import List, Tuple
117+
118+
import gradio as gr
119+
120+
from graphgen.bases import BaseGraphStorage, BaseTokenizer
121+
from graphgen.utils import run_concurrent
122+
123+
124+
async def pre_tokenize(
125+
graph_storage: BaseGraphStorage,
126+
tokenizer: BaseTokenizer,
127+
edges: List[Tuple],
128+
nodes: List[Tuple],
129+
progress_bar: gr.Progress = None,
130+
max_concurrent: int = 1000,
131+
) -> Tuple[List, List]:
132+
"""为 edges/nodes 补 token-length 并回写存储,并发 1000,带进度条。"""
133+
sem = asyncio.Semaphore(max_concurrent)
134+
135+
async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple:
136+
async with sem:
137+
data = obj[1] if is_node else obj[2]
138+
if "length" not in data:
139+
loop = asyncio.get_event_loop()
140+
data["length"] = len(
141+
await loop.run_in_executor(
142+
None, tokenizer.encode, data["description"]
143+
)
144+
)
145+
if is_node:
146+
graph_storage.update_node(obj[0], obj[1])
147+
else:
148+
graph_storage.update_edge(obj[0], obj[1], obj[2])
149+
return obj
150+
151+
new_edges, new_nodes = await asyncio.gather(
152+
run_concurrent(
153+
lambda e: _patch_and_write(e, is_node=False),
154+
edges,
155+
desc="Pre-tokenizing edges",
156+
unit="edge",
157+
progress_bar=progress_bar,
158+
),
159+
run_concurrent(
160+
lambda n: _patch_and_write(n, is_node=True),
161+
nodes,
162+
desc="Pre-tokenizing nodes",
163+
unit="node",
164+
progress_bar=progress_bar,
165+
),
166+
)
167+
168+
graph_storage.index_done_callback()
169+
return new_edges, new_nodes
170+

graphgen/operators/partition/pre_tokenize.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

graphgen/operators/quiz/quiz_service.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,16 @@
55
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
66
from graphgen.common import init_llm, init_storage
77
from graphgen.models import QuizGenerator
8-
from graphgen.utils import compute_content_hash, run_concurrent, logger
8+
from graphgen.utils import compute_content_hash, logger, run_concurrent
99

1010

1111
class QuizService:
12-
def __init__(self, working_dir: str = "cache", quiz_samples: int = 1, concurrency_limit: int = 200):
12+
def __init__(
13+
self,
14+
working_dir: str = "cache",
15+
quiz_samples: int = 1,
16+
concurrency_limit: int = 200,
17+
):
1318
self.quiz_samples = quiz_samples
1419
self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
1520
self.graph_storage: BaseGraphStorage = init_storage(
@@ -20,7 +25,6 @@ def __init__(self, working_dir: str = "cache", quiz_samples: int = 1, concurrenc
2025
backend="json_kv", working_dir=working_dir, namespace="quiz"
2126
)
2227
self.generator = QuizGenerator(self.llm_client)
23-
2428
self.concurrency_limit = concurrency_limit
2529

2630
def __call__(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
@@ -80,7 +84,6 @@ def quiz(self) -> Iterable[pd.DataFrame]:
8084
description = node_data["description"]
8185
items.append(description)
8286

83-
print("Total descriptions to quiz: %d", len(items))
8487
logger.info("Total descriptions to quiz: %d", len(items))
8588

8689
for i in range(0, len(items), self.concurrency_limit):

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ fastapi
2121
trafilatura
2222
aiohttp
2323
socksio
24+
pydantic
25+
ray==2.52.1
2426

2527
leidenalg
2628
igraph

0 commit comments

Comments
 (0)