Skip to content

Commit b5ed725

Browse files
Merge pull request #57 from open-sciencelab/refactor-pipeline
refactor: refactor generate pipeline
2 parents 6447f2a + fb3fc25 commit b5ed725

File tree

14 files changed

+205
-237
lines changed

14 files changed

+205
-237
lines changed

.env.example

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
TOKENIZER_MODEL=
12
SYNTHESIZER_MODEL=
23
SYNTHESIZER_BASE_URL=
34
SYNTHESIZER_API_KEY=

graphgen/configs/aggregated_config.yaml

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,21 @@ split:
66
search: # web search configuration
77
enabled: false # whether to enable web search
88
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9-
output_data_type: aggregated # atomic, aggregated, multi_hop, cot
10-
output_data_format: ChatML # Alpaca, Sharegpt, ChatML
11-
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
12-
quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
9+
quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
1310
enabled: true
1411
quiz_samples: 2 # number of quiz samples to generate
1512
re_judge: false # whether to re-judge the existing quiz samples
16-
traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
17-
bidirectional: true # whether to traverse the graph in both directions
18-
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
19-
expand_method: max_width # expand method, support: max_width, max_depth
20-
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
21-
max_depth: 5 # maximum depth for graph traversal
22-
max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
23-
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
24-
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
13+
partition: # graph partition configuration
14+
method: ece # ece is a custom partition method based on comprehension loss
15+
method_params:
16+
bidirectional: true # whether to traverse the graph in both directions
17+
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
18+
expand_method: max_width # expand method, support: max_width, max_depth
19+
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
20+
max_depth: 5 # maximum depth for graph traversal
21+
max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
22+
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
23+
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
24+
generate:
25+
mode: aggregated # atomic, aggregated, multi_hop, cot
26+
data_format: ChatML # Alpaca, Sharegpt, ChatML

graphgen/configs/atomic_config.yaml

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,21 @@ split:
66
search: # web search configuration
77
enabled: false # whether to enable web search
88
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9-
output_data_type: atomic # atomic, aggregated, multi_hop, cot
10-
output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
11-
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
12-
quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
9+
quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
1310
enabled: true
1411
quiz_samples: 2 # number of quiz samples to generate
1512
re_judge: false # whether to re-judge the existing quiz samples
16-
traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
17-
bidirectional: true # whether to traverse the graph in both directions
18-
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
19-
expand_method: max_width # expand method, support: max_width, max_depth
20-
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
21-
max_depth: 3 # maximum depth for graph traversal
22-
max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
23-
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
24-
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
13+
partition: # graph partition configuration
14+
method: ece # ece is a custom partition method based on comprehension loss
15+
method_params:
16+
bidirectional: true # whether to traverse the graph in both directions
17+
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
18+
expand_method: max_width # expand method, support: max_width, max_depth
19+
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
20+
max_depth: 3 # maximum depth for graph traversal
21+
max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
22+
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
23+
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
24+
generate:
25+
mode: atomic # atomic, aggregated, multi_hop, cot
26+
data_format: Alpaca # Alpaca, Sharegpt, ChatML

graphgen/configs/cot_config.yaml

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@ split:
66
search: # web search configuration
77
enabled: false # whether to enable web search
88
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9-
output_data_type: cot # atomic, aggregated, multi_hop, cot
10-
output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
11-
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
12-
method_params:
13-
method: leiden
14-
max_size: 20 # Maximum size of communities
15-
use_lcc: false
16-
random_seed: 42
9+
quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
10+
enabled: false
11+
partition: # graph partition configuration
12+
method: leiden # leiden is a community detection algorithm
13+
method_params:
14+
max_size: 20 # Maximum size of communities
15+
use_lcc: false
16+
random_seed: 42
17+
generate:
18+
mode: cot # atomic, aggregated, multi_hop, cot
19+
data_format: Sharegpt # Alpaca, Sharegpt, ChatML

graphgen/configs/multi_hop_config.yaml

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,21 @@ split:
66
search: # web search configuration
77
enabled: false # whether to enable web search
88
search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
9-
output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
10-
output_data_format: ChatML # Alpaca, Sharegpt, ChatML
11-
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
12-
quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
9+
quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
1310
enabled: false
1411
quiz_samples: 2 # number of quiz samples to generate
1512
re_judge: false # whether to re-judge the existing quiz samples
16-
traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
17-
bidirectional: true # whether to traverse the graph in both directions
18-
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
19-
expand_method: max_width # expand method, support: max_width, max_depth
20-
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
21-
max_depth: 1 # maximum depth for graph traversal
22-
max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
23-
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
24-
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
13+
partition: # graph partition configuration
14+
method: ece # ece is a custom partition method based on comprehension loss
15+
method_params:
16+
bidirectional: true # whether to traverse the graph in both directions
17+
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
18+
expand_method: max_width # expand method, support: max_width, max_depth
19+
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
20+
max_depth: 1 # maximum depth for graph traversal
21+
max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
22+
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
23+
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
24+
generate:
25+
mode: multi_hop # atomic, aggregated, multi_hop, cot
26+
data_format: ChatML # Alpaca, Sharegpt, ChatML

graphgen/generate.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import yaml
77
from dotenv import load_dotenv
88

9-
from .graphgen import GraphGen
10-
from .utils import logger, set_logger
9+
from graphgen.graphgen import GraphGen
10+
from graphgen.utils import logger, set_logger
1111

1212
sys_path = os.path.abspath(os.path.dirname(__file__))
1313

@@ -50,50 +50,51 @@ def main():
5050
with open(args.config_file, "r", encoding="utf-8") as f:
5151
config = yaml.load(f, Loader=yaml.FullLoader)
5252

53-
output_data_type = config["output_data_type"]
53+
mode = config["generate"]["mode"]
5454
unique_id = int(time.time())
5555

56-
output_path = os.path.join(
57-
working_dir, "data", "graphgen", f"{unique_id}_{output_data_type}"
58-
)
56+
output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}")
5957
set_working_dir(output_path)
6058

6159
set_logger(
62-
os.path.join(output_path, f"{unique_id}.log"),
60+
os.path.join(output_path, f"{unique_id}_{mode}.log"),
6361
if_stream=True,
6462
)
6563
logger.info(
6664
"GraphGen with unique ID %s logging to %s",
6765
unique_id,
68-
os.path.join(
69-
working_dir, "logs", f"{unique_id}_graphgen_{output_data_type}.log"
70-
),
66+
os.path.join(working_dir, f"{unique_id}_{mode}.log"),
7167
)
7268

73-
graph_gen = GraphGen(working_dir=working_dir, unique_id=unique_id, config=config)
69+
graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir)
7470

75-
graph_gen.insert()
71+
graph_gen.insert(read_config=config["read"], split_config=config["split"])
7672

77-
if config["search"]["enabled"]:
78-
graph_gen.search()
73+
graph_gen.search(search_config=config["search"])
7974

8075
# Use pipeline according to the output data type
81-
if output_data_type in ["atomic", "aggregated", "multi_hop"]:
82-
if "quiz_and_judge_strategy" in config and config[
83-
"quiz_and_judge_strategy"
84-
].get("enabled", False):
85-
graph_gen.quiz()
86-
graph_gen.judge()
76+
if mode in ["atomic", "aggregated", "multi_hop"]:
77+
logger.info("Generation mode set to '%s'. Start generation.", mode)
78+
if "quiz_and_judge" in config and config["quiz_and_judge"]["enabled"]:
79+
graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"])
8780
else:
8881
logger.warning(
8982
"Quiz and Judge strategy is disabled. Edge sampling falls back to random."
9083
)
91-
graph_gen.traverse_strategy.edge_sampling = "random"
92-
graph_gen.traverse()
93-
elif output_data_type == "cot":
94-
graph_gen.generate_reasoning(method_params=config["method_params"])
84+
assert (
85+
config["partition"]["method"] == "ece"
86+
and "ece_params" in config["partition"]
87+
), "Only ECE partition with edge sampling is supported."
88+
config["partition"]["method_params"]["edge_sampling"] = "random"
89+
elif mode == "cot":
90+
logger.info("Generation mode set to 'cot'. Start generation.")
9591
else:
96-
raise ValueError(f"Unsupported output data type: {output_data_type}")
92+
raise ValueError(f"Unsupported output data type: {mode}")
93+
94+
graph_gen.generate(
95+
partition_config=config["partition"],
96+
generate_config=config["generate"],
97+
)
9798

9899
save_config(os.path.join(output_path, "config.yaml"), config)
99100
logger.info("GraphGen completed successfully. Data saved to %s", output_path)

0 commit comments

Comments (0)