@@ -24,35 +24,24 @@ def save_config(global_config):
2424
2525if __name__ == '__main__' :
2626 parser = argparse .ArgumentParser ()
27- parser .add_argument ('--input_file ' ,
28- help = 'Raw context jsonl path .' ,
29- default = 'resources/examples/chunked_demo.json ' ,
27+ parser .add_argument ('--config_file ' ,
28+ help = 'Config parameters for GraphGen .' ,
29+ default = 'graphgen_config.yaml ' ,
3030 type = str )
31- parser .add_argument ('--data_type' ,
32- help = 'Data type of input file. (Raw context or chunked context)' ,
33- choices = ['raw' , 'chunked' ],
34- default = 'raw' ,
35- type = str )
36- parser .add_argument ('--web_search' ,
37- help = 'Search node info from wiki.' ,
38- action = 'store_true' ,
39- default = False )
40- parser .add_argument ('--tokenizer' ,
41- help = 'Tokenizer name.' ,
42- default = 'cl100k_base' ,
43- type = str )
44-
4531 args = parser .parse_args ()
46- input_file = args .input_file
32+ with open (args .config_file , "r" , encoding = 'utf-8' ) as f :
33+ config = yaml .load (f , Loader = yaml .FullLoader )
4734
48- if args .data_type == 'raw' :
35+ input_file = config ['input_file' ]
36+
37+ if config ['data_type' ] == 'raw' :
4938 with open (input_file , "r" , encoding = 'utf-8' ) as f :
5039 data = [json .loads (line ) for line in f ]
51- elif args . data_type == 'chunked' :
40+ elif config [ ' data_type' ] == 'chunked' :
5241 with open (input_file , "r" , encoding = 'utf-8' ) as f :
5342 data = json .load (f )
5443 else :
55- raise ValueError (f"Invalid data type: { args . data_type } " )
44+ raise ValueError (f"Invalid data type: { config [ ' data_type' ] } " )
5645
5746 synthesizer_llm_client = OpenAIModel (
5847 model_name = os .getenv ("TEACHER_MODEL" ),
@@ -65,36 +54,27 @@ def save_config(global_config):
6554 base_url = os .getenv ("STUDENT_BASE_URL" )
6655 )
6756
68- traverse_strategy = TraverseStrategy ()
57+ traverse_strategy = TraverseStrategy (
58+ ** config ['traverse_strategy' ]
59+ )
6960
7061 graph_gen = GraphGen (
7162 unique_id = unique_id ,
72- teacher_llm_client = synthesizer_llm_client ,
73- student_llm_client = training_llm_client ,
74- if_web_search = args . web_search ,
63+ synthesizer_llm_client = synthesizer_llm_client ,
64+ training_llm_client = training_llm_client ,
65+ if_web_search = config [ ' web_search' ] ,
7566 tokenizer_instance = Tokenizer (
76- model_name = args . tokenizer
67+ model_name = config [ ' tokenizer' ]
7768 ),
7869 traverse_strategy = traverse_strategy
7970 )
8071
81- graph_gen .insert (data , args . data_type )
72+ graph_gen .insert (data , config [ ' data_type' ] )
8273
8374 graph_gen .quiz (max_samples = 2 )
8475
8576 graph_gen .judge (re_judge = False )
8677
8778 graph_gen .traverse ()
8879
89- config = {
90- "unique_id" : unique_id ,
91- "input_file" : input_file ,
92- "data_type" : args .data_type ,
93- "web_search" : args .web_search ,
94- "tokenizer" : args .tokenizer ,
95- "teacher_model" : os .getenv ("TEACHER_MODEL" ),
96- "student_model" : os .getenv ("STUDENT_MODEL" ),
97- }
98- traverse_strategy_yaml = traverse_strategy .to_yaml ()
99- config .update (traverse_strategy_yaml )
10080 save_config (config )
0 commit comments