Skip to content

Commit 9434fd3

Browse files
refactor: refactor generate pipeline
1 parent 6447f2a commit 9434fd3

File tree

8 files changed

+114
-115
lines changed

8 files changed

+114
-115
lines changed

.env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
TOKENIZER_MODEL=
12
SYNTHESIZER_MODEL=
23
SYNTHESIZER_BASE_URL=
34
SYNTHESIZER_API_KEY=

graphgen/configs/aggregated_config.yaml

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,21 @@ split:
66
search: # web search configuration
77
enabled: false # whether to enable web search
88
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9-
output_data_type: aggregated # atomic, aggregated, multi_hop, cot
10-
output_data_format: ChatML # Alpaca, Sharegpt, ChatML
11-
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
12-
quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
9+
quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
1310
enabled: true
1411
quiz_samples: 2 # number of quiz samples to generate
1512
re_judge: false # whether to re-judge the existing quiz samples
16-
traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
17-
bidirectional: true # whether to traverse the graph in both directions
18-
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
19-
expand_method: max_width # expand method, support: max_width, max_depth
20-
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
21-
max_depth: 5 # maximum depth for graph traversal
22-
max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
23-
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
24-
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
13+
partition: # graph partition configuration
14+
method: ece # ece is a custom partition method based on comprehension loss
15+
ece_params:
16+
bidirectional: true # whether to traverse the graph in both directions
17+
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
18+
expand_method: max_width # expand method, support: max_width, max_depth
19+
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
20+
max_depth: 5 # maximum depth for graph traversal
21+
max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
22+
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
23+
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
24+
generate:
25+
mode: aggregated # atomic, aggregated, multi_hop, cot
26+
data_format: ChatML # Alpaca, Sharegpt, ChatML

graphgen/configs/atomic_config.yaml

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,21 @@ split:
66
search: # web search configuration
77
enabled: false # whether to enable web search
88
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9-
output_data_type: atomic # atomic, aggregated, multi_hop, cot
10-
output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
11-
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
12-
quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
9+
quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
1310
enabled: true
1411
quiz_samples: 2 # number of quiz samples to generate
1512
re_judge: false # whether to re-judge the existing quiz samples
16-
traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
17-
bidirectional: true # whether to traverse the graph in both directions
18-
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
19-
expand_method: max_width # expand method, support: max_width, max_depth
20-
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
21-
max_depth: 3 # maximum depth for graph traversal
22-
max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
23-
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
24-
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
13+
partition: # graph partition configuration
14+
method: ece # ece is a custom partition method based on comprehension loss
15+
ece_params:
16+
bidirectional: true # whether to traverse the graph in both directions
17+
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
18+
expand_method: max_width # expand method, support: max_width, max_depth
19+
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
20+
max_depth: 3 # maximum depth for graph traversal
21+
max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
22+
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
23+
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
24+
generate:
25+
mode: atomic # atomic, aggregated, multi_hop, cot
26+
data_format: Alpaca # Alpaca, Sharegpt, ChatML

graphgen/configs/cot_config.yaml

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@ split:
66
search: # web search configuration
77
enabled: false # whether to enable web search
88
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9-
output_data_type: cot # atomic, aggregated, multi_hop, cot
10-
output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
11-
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
12-
method_params:
13-
method: leiden
14-
max_size: 20 # Maximum size of communities
15-
use_lcc: false
16-
random_seed: 42
9+
quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
10+
enabled: false
11+
partition: # graph partition configuration
12+
method: leiden # leiden is a community detection algorithm
13+
leiden_params:
14+
max_size: 20 # Maximum size of communities
15+
use_lcc: false
16+
random_seed: 42
17+
generate:
18+
mode: cot # atomic, aggregated, multi_hop, cot
19+
data_format: Sharegpt # Alpaca, Sharegpt, ChatML

graphgen/configs/multi_hop_config.yaml

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,21 @@ split:
66
search: # web search configuration
77
enabled: false # whether to enable web search
88
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9-
output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
10-
output_data_format: ChatML # Alpaca, Sharegpt, ChatML
11-
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
12-
quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
9+
quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
1310
enabled: false
1411
quiz_samples: 2 # number of quiz samples to generate
1512
re_judge: false # whether to re-judge the existing quiz samples
16-
traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
17-
bidirectional: true # whether to traverse the graph in both directions
18-
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
19-
expand_method: max_width # expand method, support: max_width, max_depth
20-
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
21-
max_depth: 1 # maximum depth for graph traversal
22-
max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
23-
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
24-
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
13+
partition: # graph partition configuration
14+
method: ece # ece is a custom partition method based on comprehension loss
15+
ece_params:
16+
bidirectional: true # whether to traverse the graph in both directions
17+
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
18+
expand_method: max_width # expand method, support: max_width, max_depth
19+
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
20+
max_depth: 1 # maximum depth for graph traversal
21+
max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
22+
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
23+
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
24+
generate:
25+
mode: multi_hop # strategy for generating multi-hop QA pairs
26+
data_format: ChatML # Alpaca, Sharegpt, ChatML

graphgen/generate.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import yaml
77
from dotenv import load_dotenv
88

9-
from .graphgen import GraphGen
10-
from .utils import logger, set_logger
9+
from graphgen.graphgen import GraphGen
10+
from graphgen.utils import logger, set_logger
1111

1212
sys_path = os.path.abspath(os.path.dirname(__file__))
1313

@@ -50,12 +50,10 @@ def main():
5050
with open(args.config_file, "r", encoding="utf-8") as f:
5151
config = yaml.load(f, Loader=yaml.FullLoader)
5252

53-
output_data_type = config["output_data_type"]
53+
mode = config["generate"]["mode"]
5454
unique_id = int(time.time())
5555

56-
output_path = os.path.join(
57-
working_dir, "data", "graphgen", f"{unique_id}_{output_data_type}"
58-
)
56+
output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}_{mode}")
5957
set_working_dir(output_path)
6058

6159
set_logger(
@@ -65,35 +63,35 @@ def main():
6563
logger.info(
6664
"GraphGen with unique ID %s logging to %s",
6765
unique_id,
68-
os.path.join(
69-
working_dir, "logs", f"{unique_id}_graphgen_{output_data_type}.log"
70-
),
66+
os.path.join(working_dir, f"{unique_id}.log"),
7167
)
7268

73-
graph_gen = GraphGen(working_dir=working_dir, unique_id=unique_id, config=config)
69+
graph_gen = GraphGen(working_dir=working_dir, output_path=output_path)
7470

75-
graph_gen.insert()
71+
graph_gen.insert(read_config=config["read"], split_config=config["split"])
7672

77-
if config["search"]["enabled"]:
78-
graph_gen.search()
73+
graph_gen.search(search_config=config["search"])
7974

8075
# Use pipeline according to the output data type
81-
if output_data_type in ["atomic", "aggregated", "multi_hop"]:
82-
if "quiz_and_judge_strategy" in config and config[
83-
"quiz_and_judge_strategy"
84-
].get("enabled", False):
85-
graph_gen.quiz()
86-
graph_gen.judge()
76+
if mode in ["atomic", "aggregated", "multi_hop"]:
77+
logger.info("Generation mode set to '%s'. Start generation.", mode)
78+
if "quiz_and_judge" in config:
79+
graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
8780
else:
8881
logger.warning(
8982
"Quiz and Judge strategy is disabled. Edge sampling falls back to random."
9083
)
91-
graph_gen.traverse_strategy.edge_sampling = "random"
92-
graph_gen.traverse()
93-
elif output_data_type == "cot":
94-
graph_gen.generate_reasoning(method_params=config["method_params"])
84+
# TODO: make edge sampling random
85+
# graph_gen.traverse_strategy.edge_sampling = "random"
86+
elif mode == "cot":
87+
logger.info("Generation mode set to 'cot'. Start generation.")
9588
else:
96-
raise ValueError(f"Unsupported output data type: {output_data_type}")
89+
raise ValueError(f"Unsupported output data type: {mode}")
90+
91+
graph_gen.generate(
92+
partition_config=config["partition"],
93+
generate_config=config["generate"],
94+
)
9795

9896
save_config(os.path.join(output_path, "config.yaml"), config)
9997
logger.info("GraphGen completed successfully. Data saved to %s", output_path)

graphgen/graphgen.py

Lines changed: 33 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import os
33
import time
4-
from dataclasses import dataclass, field
4+
from dataclasses import dataclass
55
from typing import Dict, cast
66

77
import gradio as gr
@@ -14,7 +14,6 @@
1414
NetworkXStorage,
1515
OpenAIClient,
1616
Tokenizer,
17-
TraverseStrategy,
1817
)
1918
from graphgen.operators import (
2019
chunk_documents,
@@ -40,30 +39,24 @@
4039

4140
@dataclass
4241
class GraphGen:
43-
unique_id: int = int(time.time())
4442
working_dir: str = os.path.join(sys_path, "cache")
45-
config: Dict = field(default_factory=dict)
43+
output_path: str = os.path.join(
44+
working_dir, "data", "graphgen", str(int(time.time()))
45+
)
4646

4747
# llm
4848
tokenizer_instance: Tokenizer = None
4949
synthesizer_llm_client: OpenAIClient = None
5050
trainee_llm_client: OpenAIClient = None
5151

52-
# search
53-
search_config: dict = field(
54-
default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]}
55-
)
56-
57-
# traversal
58-
traverse_strategy: TraverseStrategy = None
59-
6052
# webui
6153
progress_bar: gr.Progress = None
6254

6355
def __post_init__(self):
6456
self.tokenizer_instance: Tokenizer = Tokenizer(
65-
model_name=self.config["tokenizer"]
57+
model_name=os.getenv("TOKENIZER_MODEL")
6658
)
59+
6760
self.synthesizer_llm_client: OpenAIClient = OpenAIClient(
6861
model_name=os.getenv("SYNTHESIZER_MODEL"),
6962
api_key=os.getenv("SYNTHESIZER_API_KEY"),
@@ -76,12 +69,6 @@ def __post_init__(self):
7669
base_url=os.getenv("TRAINEE_BASE_URL"),
7770
tokenizer=self.tokenizer_instance,
7871
)
79-
self.search_config = self.config["search"]
80-
81-
if "traverse_strategy" in self.config:
82-
self.traverse_strategy = TraverseStrategy(
83-
**self.config["traverse_strategy"]
84-
)
8572

8673
self.full_docs_storage: JsonKVStorage = JsonKVStorage(
8774
self.working_dir, namespace="full_docs"
@@ -99,24 +86,17 @@ def __post_init__(self):
9986
self.working_dir, namespace="rephrase"
10087
)
10188
self.qa_storage: JsonListStorage = JsonListStorage(
102-
os.path.join(
103-
self.working_dir,
104-
"data",
105-
"graphgen",
106-
f"{self.unique_id}_{self.config['output_data_type']}",
107-
),
89+
self.working_dir,
10890
namespace="qa",
10991
)
11092

11193
@async_to_sync_method
112-
async def insert(self):
94+
async def insert(self, read_config: Dict, split_config: Dict):
11395
"""
11496
insert chunks into the graph
11597
"""
116-
input_file = self.config["read"]["input_file"]
117-
11898
# Step 1: Read files
119-
data = read_files(input_file)
99+
data = read_files(read_config["input_file"])
120100
if len(data) == 0:
121101
logger.warning("No data to process")
122102
return
@@ -141,8 +121,8 @@ async def insert(self):
141121

142122
inserting_chunks = await chunk_documents(
143123
new_docs,
144-
self.config["split"]["chunk_size"],
145-
self.config["split"]["chunk_overlap"],
124+
split_config["chunk_size"],
125+
split_config["chunk_overlap"],
146126
self.tokenizer_instance,
147127
self.progress_bar,
148128
)
@@ -178,6 +158,7 @@ async def insert(self):
178158
return
179159

180160
await self._insert_done()
161+
return _add_entities_and_relations
181162

182163
async def _insert_done(self):
183164
tasks = []
@@ -193,14 +174,12 @@ async def _insert_done(self):
193174
await asyncio.gather(*tasks)
194175

195176
@async_to_sync_method
196-
async def search(self):
177+
async def search(self, search_config: Dict):
197178
logger.info(
198-
"Search is %s", "enabled" if self.search_config["enabled"] else "disabled"
179+
"Search is %s", "enabled" if search_config["enabled"] else "disabled"
199180
)
200-
if self.search_config["enabled"]:
201-
logger.info(
202-
"[Search] %s ...", ", ".join(self.search_config["search_types"])
203-
)
181+
if search_config["enabled"]:
182+
logger.info("[Search] %s ...", ", ".join(search_config["search_types"]))
204183
all_nodes = await self.graph_storage.get_all_nodes()
205184
all_nodes_names = [node[0] for node in all_nodes]
206185
new_search_entities = await self.full_docs_storage.filter_keys(
@@ -210,7 +189,7 @@ async def search(self):
210189
"[Search] Found %d entities to search", len(new_search_entities)
211190
)
212191
_add_search_data = await search_all(
213-
search_types=self.search_config["search_types"],
192+
search_types=search_config["search_types"],
214193
search_entities=new_search_entities,
215194
)
216195
if _add_search_data:
@@ -230,27 +209,37 @@ async def search(self):
230209
await self.insert()
231210

232211
@async_to_sync_method
233-
async def quiz(self):
234-
max_samples = self.config["quiz_and_judge_strategy"]["quiz_samples"]
212+
async def quiz_and_judge(self, quiz_and_judge_config: Dict):
213+
if quiz_and_judge_config is None or not quiz_and_judge_config.get(
214+
"enabled", False
215+
):
216+
logger.warning("Quiz and Judge is not used in this pipeline.")
217+
return
218+
max_samples = quiz_and_judge_config["quiz_samples"]
235219
await quiz(
236220
self.synthesizer_llm_client,
237221
self.graph_storage,
238222
self.rephrase_storage,
239223
max_samples,
240224
)
241-
await self.rephrase_storage.index_done_callback()
242225

243-
@async_to_sync_method
244-
async def judge(self):
245-
re_judge = self.config["quiz_and_judge_strategy"]["re_judge"]
226+
# TODO: assert trainee_llm_client is valid before judge
227+
re_judge = quiz_and_judge_config["re_judge"]
246228
_update_relations = await judge_statement(
247229
self.trainee_llm_client,
248230
self.graph_storage,
249231
self.rephrase_storage,
250232
re_judge,
251233
)
234+
await self.rephrase_storage.index_done_callback()
252235
await _update_relations.index_done_callback()
253236

237+
@async_to_sync_method
238+
async def generate(self, partition_config: Dict, generate_config: Dict):
239+
# Step 1: partition the graph
240+
# TODO: implement graph partitioning, e.g. Partitioner().partition(self.graph_storage)
241+
pass
242+
254243
@async_to_sync_method
255244
async def traverse(self):
256245
output_data_type = self.config["output_data_type"]

0 commit comments

Comments
 (0)