Skip to content

Commit bc07222

Browse files
fix: fix quiz params
1 parent 99a6e5f commit bc07222

File tree

9 files changed

+190
-145
lines changed

9 files changed

+190
-145
lines changed

graphgen/operators/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
from .quiz import QuizService
77
from .read import read
88
from .search import search_all
9+
from .judge import JudgeService
910

1011
operators = {
1112
"read": read,
1213
"chunk": ChunkService,
1314
"build_kg": BuildKGService,
1415
"quiz": QuizService,
16+
"judge": JudgeService,
1517
"extract_info": extract_info,
1618
"search_all": search_all,
1719
"partition_kg": partition_kg,

graphgen/operators/chunk/chunk_service.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import pandas as pd
66

7+
from graphgen.common import init_storage
78
from graphgen.models import (
89
ChineseRecursiveTextSplitter,
910
RecursiveCharacterSplitter,
@@ -40,9 +41,14 @@ def split_chunks(text: str, language: str = "en", **kwargs) -> list:
4041

4142

4243
class ChunkService:
43-
def __init__(self, **chunk_kwargs):
44+
def __init__(self, working_dir: str = "cache", **chunk_kwargs):
4445
tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base")
4546
self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model)
47+
self.chunk_storage = init_storage(
48+
backend="json_kv",
49+
working_dir=working_dir,
50+
namespace="chunk",
51+
)
4652
self.chunk_kwargs = chunk_kwargs
4753

4854
def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
@@ -88,4 +94,8 @@ def chunk_documents(self, new_docs: list) -> list:
8894
**doc,
8995
}
9096
)
97+
self.chunk_storage.upsert(
98+
{chunk["_chunk_id"]: chunk for chunk in chunks}
99+
)
100+
self.chunk_storage.index_done_callback()
91101
return chunks
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .extract_info import extract_info
1+
from .extract import extract_info
File renamed without changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .judge_service import JudgeService

graphgen/operators/judge/judge.py

Lines changed: 0 additions & 139 deletions
This file was deleted.
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import math
2+
3+
import gradio as gr
4+
5+
from graphgen.bases import BaseLLMWrapper
6+
from graphgen.models import JsonKVStorage, NetworkXStorage
7+
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
8+
from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy
9+
10+
11+
import math
12+
from collections.abc import Iterable
13+
14+
import pandas as pd
15+
16+
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
17+
from graphgen.common import init_llm, init_storage
18+
from graphgen.models import NetworkXStorage, JsonKVStorage
19+
from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT
20+
from graphgen.utils import logger, run_concurrent, yes_no_loss_entropy
21+
22+
23+
class JudgeService:
    """Service for judging graph edges and nodes using a trainee LLM.

    Currently a stub: it wires up the trainee LLM client, while the
    actual statement-scoring logic (top-k token judgement +
    ``yes_no_loss_entropy``) remains commented out below pending
    migration to the new service API.
    """

    def __init__(self, working_dir: str = "cache"):
        """Initialize the service.

        :param working_dir: cache directory, kept for interface parity
            with the other operator services (ChunkService, QuizService).
            NOTE(review): not yet consumed — presumably the graph and
            rephrase storages will be initialized from it here; confirm
            when the judging logic is restored.
        """
        self.working_dir = working_dir
        self.llm_client: BaseLLMWrapper = init_llm("trainee")

    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
        """Pipeline entry point; the incoming batch is not consumed."""
        return pd.DataFrame([{"status": "judging_completed"}])

    def judge(self) -> Iterable[pd.DataFrame]:
        """Judge the statements in the graph storage.

        :return: an iterable of result frames (currently empty — the
            implementation is stubbed out).
        """
        # Return an empty iterator rather than None so callers that
        # iterate the annotated Iterable do not crash on the stub.
        return iter(())
39+
40+
41+
42+
# async def judge_statement( # pylint: disable=too-many-statements
43+
# trainee_llm_client: BaseLLMWrapper,
44+
# graph_storage: NetworkXStorage,
45+
# rephrase_storage: JsonKVStorage,
46+
# re_judge: bool = False,
47+
# progress_bar: gr.Progress = None,
48+
# ) -> NetworkXStorage:
49+
# """
50+
# Get all edges and nodes and judge them
51+
#
52+
# :param trainee_llm_client: judge the statements to get comprehension loss
53+
# :param graph_storage: graph storage instance
54+
# :param rephrase_storage: rephrase storage instance
55+
# :param re_judge: re-judge the relations
56+
# :param progress_bar
57+
# :return:
58+
# """
59+
#
60+
# async def _judge_single_relation(
61+
# edge: tuple,
62+
# ):
63+
# source_id = edge[0]
64+
# target_id = edge[1]
65+
# edge_data = edge[2]
66+
#
67+
# if (not re_judge) and "loss" in edge_data and edge_data["loss"] is not None:
68+
# logger.debug(
69+
# "Edge %s -> %s already judged, loss: %s, skip",
70+
# source_id,
71+
# target_id,
72+
# edge_data["loss"],
73+
# )
74+
# return source_id, target_id, edge_data
75+
#
76+
# description = edge_data["description"]
77+
#
78+
# try:
79+
# descriptions = rephrase_storage.get_by_id(description)
80+
# assert descriptions is not None
81+
#
82+
# judgements = []
83+
# gts = [gt for _, gt in descriptions]
84+
# for description, gt in descriptions:
85+
# judgement = await trainee_llm_client.generate_topk_per_token(
86+
# STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
87+
# )
88+
# judgements.append(judgement[0].top_candidates)
89+
#
90+
# loss = yes_no_loss_entropy(judgements, gts)
91+
#
92+
# logger.debug(
93+
# "Edge %s -> %s description: %s loss: %s",
94+
# source_id,
95+
# target_id,
96+
# description,
97+
# loss,
98+
# )
99+
#
100+
# edge_data["loss"] = loss
101+
# except Exception as e: # pylint: disable=broad-except
102+
# logger.error(
103+
# "Error in judging relation %s -> %s: %s", source_id, target_id, e
104+
# )
105+
# logger.info("Use default loss 0.1")
106+
# edge_data["loss"] = -math.log(0.1)
107+
#
108+
# graph_storage.update_edge(source_id, target_id, edge_data)
109+
# return source_id, target_id, edge_data
110+
#
111+
# edges = graph_storage.get_all_edges()
112+
#
113+
# await run_concurrent(
114+
# _judge_single_relation,
115+
# edges,
116+
# desc="Judging relations",
117+
# unit="relation",
118+
# progress_bar=progress_bar,
119+
# )
120+
#
121+
# async def _judge_single_entity(
122+
# node: tuple,
123+
# ):
124+
# node_id = node[0]
125+
# node_data = node[1]
126+
#
127+
# if (not re_judge) and "loss" in node_data and node_data["loss"] is not None:
128+
# logger.debug(
129+
# "Node %s already judged, loss: %s, skip", node_id, node_data["loss"]
130+
# )
131+
# return node_id, node_data
132+
#
133+
# description = node_data["description"]
134+
#
135+
# try:
136+
# descriptions = rephrase_storage.get_by_id(description)
137+
# assert descriptions is not None
138+
#
139+
# judgements = []
140+
# gts = [gt for _, gt in descriptions]
141+
# for description, gt in descriptions:
142+
# judgement = await trainee_llm_client.generate_topk_per_token(
143+
# STATEMENT_JUDGEMENT_PROMPT["TEMPLATE"].format(statement=description)
144+
# )
145+
# judgements.append(judgement[0].top_candidates)
146+
#
147+
# loss = yes_no_loss_entropy(judgements, gts)
148+
#
149+
# logger.debug("Node %s description: %s loss: %s", node_id, description, loss)
150+
#
151+
# node_data["loss"] = loss
152+
# except Exception as e: # pylint: disable=broad-except
153+
# logger.error("Error in judging entity %s: %s", node_id, e)
154+
# logger.error("Use default loss 0.1")
155+
# node_data["loss"] = -math.log(0.1)
156+
#
157+
# graph_storage.update_node(node_id, node_data)
158+
# return node_id, node_data
159+
#
160+
# nodes = graph_storage.get_all_nodes()
161+
#
162+
# await run_concurrent(
163+
# _judge_single_entity,
164+
# nodes,
165+
# desc="Judging entities",
166+
# unit="entity",
167+
# progress_bar=progress_bar,
168+
# )
169+
#
170+
# return graph_storage
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .quiz import QuizService
1+
from .quiz_service import QuizService
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
66
from graphgen.common import init_llm, init_storage
77
from graphgen.models import QuizGenerator
8-
from graphgen.utils import compute_content_hash, logger, run_concurrent
8+
from graphgen.utils import compute_content_hash, run_concurrent, logger
99

1010

1111
class QuizService:
12-
def __init__(self, working_dir: str = "cache", quiz_samples: int = 1):
12+
def __init__(self, working_dir: str = "cache", quiz_samples: int = 1, concurrency_limit: int = 200):
1313
self.quiz_samples = quiz_samples
1414
self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
1515
self.graph_storage: BaseGraphStorage = init_storage(
@@ -21,7 +21,7 @@ def __init__(self, working_dir: str = "cache", quiz_samples: int = 1):
2121
)
2222
self.generator = QuizGenerator(self.llm_client)
2323

24-
self.concurrency_limit = 20
24+
self.concurrency_limit = concurrency_limit
2525

2626
def __call__(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]:
2727
# this operator does not consume any batch data
@@ -80,6 +80,7 @@ def quiz(self) -> Iterable[pd.DataFrame]:
8080
description = node_data["description"]
8181
items.append(description)
8282

83+
print("Total descriptions to quiz: %d", len(items))
8384
logger.info("Total descriptions to quiz: %d", len(items))
8485

8586
for i in range(0, len(items), self.concurrency_limit):

0 commit comments

Comments
 (0)