|
6 | 6 | import yaml |
7 | 7 | from dotenv import load_dotenv |
8 | 8 |
|
9 | | -from .graphgen import GraphGen |
10 | | -from .utils import logger, set_logger |
| 9 | +from graphgen.graphgen import GraphGen |
| 10 | +from graphgen.utils import logger, set_logger |
11 | 11 |
|
12 | 12 | sys_path = os.path.abspath(os.path.dirname(__file__)) |
13 | 13 |
|
@@ -50,50 +50,51 @@ def main(): |
50 | 50 | with open(args.config_file, "r", encoding="utf-8") as f: |
51 | 51 | config = yaml.load(f, Loader=yaml.FullLoader) |
52 | 52 |
|
53 | | - output_data_type = config["output_data_type"] |
| 53 | + mode = config["generate"]["mode"] |
54 | 54 | unique_id = int(time.time()) |
55 | 55 |
|
56 | | - output_path = os.path.join( |
57 | | - working_dir, "data", "graphgen", f"{unique_id}_{output_data_type}" |
58 | | - ) |
| 56 | + output_path = os.path.join(working_dir, "data", "graphgen", f"{unique_id}") |
59 | 57 | set_working_dir(output_path) |
60 | 58 |
|
61 | 59 | set_logger( |
62 | | - os.path.join(output_path, f"{unique_id}.log"), |
| 60 | + os.path.join(output_path, f"{unique_id}_{mode}.log"), |
63 | 61 | if_stream=True, |
64 | 62 | ) |
65 | 63 | logger.info( |
66 | 64 | "GraphGen with unique ID %s logging to %s", |
67 | 65 | unique_id, |
68 | | - os.path.join( |
69 | | - working_dir, "logs", f"{unique_id}_graphgen_{output_data_type}.log" |
70 | | - ), |
| 66 | + os.path.join(working_dir, f"{unique_id}_{mode}.log"), |
71 | 67 | ) |
72 | 68 |
|
73 | | - graph_gen = GraphGen(working_dir=working_dir, unique_id=unique_id, config=config) |
| 69 | + graph_gen = GraphGen(unique_id=unique_id, working_dir=working_dir) |
74 | 70 |
|
75 | | - graph_gen.insert() |
| 71 | + graph_gen.insert(read_config=config["read"], split_config=config["split"]) |
76 | 72 |
|
77 | | - if config["search"]["enabled"]: |
78 | | - graph_gen.search() |
| 73 | + graph_gen.search(search_config=config["search"]) |
79 | 74 |
|
80 | 75 | # Use pipeline according to the output data type |
81 | | - if output_data_type in ["atomic", "aggregated", "multi_hop"]: |
82 | | - if "quiz_and_judge_strategy" in config and config[ |
83 | | - "quiz_and_judge_strategy" |
84 | | - ].get("enabled", False): |
85 | | - graph_gen.quiz() |
86 | | - graph_gen.judge() |
| 76 | + if mode in ["atomic", "aggregated", "multi_hop"]: |
| 77 | + logger.info("Generation mode set to '%s'. Start generation.", mode) |
| 78 | + if "quiz_and_judge" in config and config["quiz_and_judge"]["enabled"]: |
| 79 | + graph_gen.quiz_and_judge(quiz_and_judge_config=config["quiz_and_judge"]) |
87 | 80 | else: |
88 | 81 | logger.warning( |
89 | 82 | "Quiz and Judge strategy is disabled. Edge sampling falls back to random." |
90 | 83 | ) |
91 | | - graph_gen.traverse_strategy.edge_sampling = "random" |
92 | | - graph_gen.traverse() |
93 | | - elif output_data_type == "cot": |
94 | | - graph_gen.generate_reasoning(method_params=config["method_params"]) |
| 84 | + assert ( |
| 85 | + config["partition"]["method"] == "ece" |
| 86 | + and "ece_params" in config["partition"] |
| 87 | + ), "Only ECE partition with edge sampling is supported." |
| 88 | + config["partition"]["method_params"]["edge_sampling"] = "random" |
| 89 | + elif mode == "cot": |
| 90 | + logger.info("Generation mode set to 'cot'. Start generation.") |
95 | 91 | else: |
96 | | - raise ValueError(f"Unsupported output data type: {output_data_type}") |
| 92 | + raise ValueError(f"Unsupported output data type: {mode}") |
| 93 | + |
| 94 | + graph_gen.generate( |
| 95 | + partition_config=config["partition"], |
| 96 | + generate_config=config["generate"], |
| 97 | + ) |
97 | 98 |
|
98 | 99 | save_config(os.path.join(output_path, "config.yaml"), config) |
99 | 100 | logger.info("GraphGen completed successfully. Data saved to %s", output_path) |
|
0 commit comments