Skip to content

Commit 948b610

Browse files
feat(graphgen): add atomic generation method
1 parent e7da0a0 commit 948b610

File tree

6 files changed

+120
-5
lines changed

6 files changed

+120
-5
lines changed

configs/config.yaml.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
qa_form: atomic
12
data_type: raw
23
input_file: resources/examples/raw_demo.jsonl
34
tokenizer: cl100k_base

configs/graphgen_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
qa_form: atomic
12
data_type: raw
23
input_file: resources/examples/raw_demo.jsonl
34
tokenizer: cl100k_base

graphgen/graphgen.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from models import Chunk, JsonKVStorage, OpenAIModel, NetworkXStorage, WikiSearch, Tokenizer, TraverseStrategy
1111
from models.storage.base_storage import StorageNameSpace
1212
from utils import create_event_loop, logger, compute_content_hash
13-
from .operators import extract_kg, search_wikipedia, quiz, judge_statement, traverse_graph_by_edge
13+
from .operators import (extract_kg, search_wikipedia, quiz, judge_statement, traverse_graph_by_edge,
14+
traverse_graph_atomically)
1415

1516

1617
sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -188,7 +189,14 @@ def traverse(self):
188189
loop.run_until_complete(self.async_traverse())
189190

190191
async def async_traverse(self):
    """Generate QA pairs from the knowledge graph and persist them.

    Dispatches on ``self.traverse_strategy.qa_form``: ``"atomic"`` builds one
    QA pair per node/edge description via ``traverse_graph_atomically``; any
    other value falls back to the edge-based traversal. Results are upserted
    into ``qa_storage`` and the storage index callback is invoked.
    """
    # Both traversal coroutines share the same positional signature, so we
    # can pick the implementation once and make a single call.
    if self.traverse_strategy.qa_form == "atomic":
        traverse_fn = traverse_graph_atomically
    else:
        traverse_fn = traverse_graph_by_edge
    results = await traverse_fn(
        self.synthesizer_llm_client,
        self.tokenizer_instance,
        self.graph_storage,
        self.traverse_strategy,
        self.text_chunks_storage,
    )
    await self.qa_storage.upsert(results)
    await self.qa_storage.index_done_callback()

graphgen/operators/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from .quiz import quiz
33
from .judge import judge_statement
44
from .search_wikipedia import search_wikipedia
5-
from .traverse_graph import traverse_graph_by_edge
5+
from .traverse_graph import traverse_graph_by_edge, traverse_graph_atomically
66

77
__all__ = [
88
"extract_kg",
99
"quiz",
1010
"judge_statement",
1111
"search_wikipedia",
12-
"traverse_graph_by_edge"
12+
"traverse_graph_by_edge",
13+
"traverse_graph_atomically"
1314
]

graphgen/operators/traverse_graph.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
import asyncio

# FIXME(review): this looks like an accidental IDE auto-import — it pulls a
# numba-*internal* module (not public API), the imported name `description`
# is unused in all visible code, and it makes numba a hard runtime dependency
# (ImportError at module load if numba is absent). Remove unless some unseen
# code in this file actually needs it.
from numba.scripts.generate_lower_listing import description
from tqdm.asyncio import tqdm as tqdm_async

from models import OpenAIModel, NetworkXStorage, TraverseStrategy, Tokenizer, JsonKVStorage
@@ -296,3 +298,103 @@ async def _process_single_batch(
296298
logger.error("Error occurred while processing batches: %s", e)
297299

298300
return results
301+
302+
303+
async def traverse_graph_atomically(
304+
llm_client: OpenAIModel,
305+
tokenizer: Tokenizer,
306+
graph_storage: NetworkXStorage,
307+
traverse_strategy: TraverseStrategy,
308+
text_chunks_storage: JsonKVStorage,
309+
max_concurrent: int = 1000
310+
) -> dict:
311+
"""
312+
Traverse the graph atomicly
313+
314+
:param llm_client
315+
:param tokenizer
316+
:param graph_storage
317+
:param traverse_strategy
318+
:param text_chunks_storage
319+
:param max_concurrent
320+
:return: question and answer
321+
"""
322+
323+
assert traverse_strategy.qa_form == "atomic"
324+
325+
semaphore = asyncio.Semaphore(max_concurrent)
326+
327+
async def _generate_question(
328+
node_or_edge: tuple
329+
):
330+
if len(node_or_edge) == 2:
331+
des = node_or_edge[0] + ": " + node_or_edge[1]['description']
332+
answer = node_or_edge[1]['description']
333+
else:
334+
des = node_or_edge[2]['description']
335+
answer = node_or_edge[2]['description']
336+
337+
async with semaphore:
338+
try:
339+
language = "Chinese" if detect_main_language(des) == "zh" else "English"
340+
question = await llm_client.generate_answer(
341+
QUESTION_GENERATION_PROMPT[language]['SINGLE_TEMPLATE'].format(
342+
answer=des
343+
)
344+
)
345+
if question.startswith("Question:"):
346+
question = question[len("Question:"):].strip()
347+
elif question.startswith("问题:"):
348+
question = question[len("问题:"):].strip()
349+
350+
question = question.strip("\"")
351+
answer = answer.strip("\"")
352+
353+
logger.info("Question: %s", question)
354+
logger.info("Answer: %s", answer)
355+
return {
356+
compute_content_hash(question): {
357+
"question": question,
358+
"answer": answer,
359+
"loss": -1,
360+
"difficulty": "medium"
361+
}
362+
}
363+
except Exception as e: # pylint: disable=broad-except
364+
logger.error("Error occurred while generating question: %s", e)
365+
return {}
366+
367+
results = {}
368+
edges = list(await graph_storage.get_all_edges())
369+
nodes = list(await graph_storage.get_all_nodes())
370+
371+
edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
372+
373+
# TODO: 需要把node的name也加进去,或者只用edge,两种都试一下
374+
tasks = []
375+
# des中可能会有SEP分割符
376+
for node in nodes:
377+
if "<SEP>" in node[1]['description']:
378+
description_list = node[1]['description'].split("<SEP>")
379+
for item in description_list:
380+
tasks.append((node[0], {"description": item}))
381+
else:
382+
tasks.append((node[0], node[1]))
383+
for edge in edges:
384+
if "<SEP>" in edge[2]['description']:
385+
description_list = edge[2]['description'].split("<SEP>")
386+
for item in description_list:
387+
tasks.append((edge[0], edge[1], {"description": item}))
388+
else:
389+
tasks.append((edge[0], edge[1], edge[2]))
390+
391+
for result in tqdm_async(
392+
asyncio.as_completed([_generate_question(task) for task in tasks]),
393+
total=len(tasks),
394+
desc="Generating questions"
395+
):
396+
try:
397+
results.update(await result)
398+
except Exception as e: # pylint: disable=broad-except
399+
logger.error("Error occurred while generating questions: %s", e)
400+
return results

models/strategy/travserse_strategy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
@dataclass
77
class TraverseStrategy(BaseStrategy):
8+
# 生成的QA形式:原子、多跳、开放性
9+
qa_form: str = "atomic"
810
# 最大边数和最大token数方法中选择一个生效
911
expand_method: str = "max_tokens" # "max_width" or "max_tokens"
1012
# 单向拓展还是双向拓展

0 commit comments

Comments
 (0)