55import yaml
66from dotenv import load_dotenv
77
8- from graphgen .graphgen import GraphGen
9- from graphgen .models import OpenAIModel , Tokenizer , TraverseStrategy
10- from graphgen .utils import set_logger
8+ from .graphgen import GraphGen
9+ from .models import OpenAIModel , Tokenizer , TraverseStrategy
10+ from .utils import set_logger
1111
1212sys_path = os .path .abspath (os .path .dirname (__file__ ))
13- unique_id = int (time .time ())
14- set_logger (os .path .join (sys_path , "cache" , "logs" , f"graphgen_{ unique_id } .log" ), if_stream = False )
15- config_path = os .path .join (sys_path , "cache" , "data" , "graphgen" , str (unique_id ), f"config-{ unique_id } .yaml" )
1613
1714load_dotenv ()
1815
19- def save_config (global_config ):
16+ def set_working_dir (folder ):
17+ os .makedirs (folder , exist_ok = True )
18+ os .makedirs (os .path .join (folder , "data" , "graphgen" ), exist_ok = True )
19+ os .makedirs (os .path .join (folder , "logs" ), exist_ok = True )
20+
21+ def save_config (config_path , global_config ):
2022 if not os .path .exists (os .path .dirname (config_path )):
2123 os .makedirs (os .path .dirname (config_path ))
2224 with open (config_path , "w" , encoding = 'utf-8' ) as config_file :
@@ -28,7 +30,19 @@ def save_config(global_config):
2830 help = 'Config parameters for GraphGen.' ,
2931 default = 'graphgen_config.yaml' ,
3032 type = str )
33+ parser .add_argument ('--output_dir' ,
34+ help = 'Output directory for GraphGen.' ,
35+ default = sys_path ,
36+ required = True ,
37+ type = str )
38+
3139 args = parser .parse_args ()
40+
41+ working_dir = args .output_dir
42+ set_working_dir (working_dir )
43+ unique_id = int (time .time ())
44+ set_logger (os .path .join (working_dir , "logs" , f"graphgen_{ unique_id } .log" ), if_stream = False )
45+
3246 with open (args .config_file , "r" , encoding = 'utf-8' ) as f :
3347 config = yaml .load (f , Loader = yaml .FullLoader )
3448
@@ -59,6 +73,7 @@ def save_config(global_config):
5973 )
6074
6175 graph_gen = GraphGen (
76+ working_dir = working_dir ,
6277 unique_id = unique_id ,
6378 synthesizer_llm_client = synthesizer_llm_client ,
6479 trainee_llm_client = trainee_llm_client ,
@@ -77,4 +92,5 @@ def save_config(global_config):
7792
7893 graph_gen .traverse ()
7994
80- save_config (config )
95+ path = os .path .join (working_dir , "data" , "graphgen" , str (unique_id ), f"config-{ unique_id } .yaml" )
96+ save_config (path , config )
0 commit comments