@@ -39,27 +39,32 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
3939 set_logger (log_file , if_stream = True )
4040 os .environ .update ({k : str (v ) for k , v in env .items ()})
4141
42- graph_gen = GraphGen (working_dir = working_dir , config = config )
43- # Set up LLM clients
44- graph_gen .synthesizer_llm_client = OpenAIClient (
42+ tokenizer_instance = Tokenizer (config .get ("tokenizer" , "cl100k_base" ))
43+ synthesizer_llm_client = OpenAIClient (
4544 model_name = env .get ("SYNTHESIZER_MODEL" , "" ),
4645 base_url = env .get ("SYNTHESIZER_BASE_URL" , "" ),
4746 api_key = env .get ("SYNTHESIZER_API_KEY" , "" ),
4847 request_limit = True ,
4948 rpm = RPM (env .get ("RPM" , 1000 )),
5049 tpm = TPM (env .get ("TPM" , 50000 )),
50+ tokenizer = tokenizer_instance ,
5151 )
52-
53- graph_gen .trainee_llm_client = OpenAIClient (
52+ trainee_llm_client = OpenAIClient (
5453 model_name = env .get ("TRAINEE_MODEL" , "" ),
5554 base_url = env .get ("TRAINEE_BASE_URL" , "" ),
5655 api_key = env .get ("TRAINEE_API_KEY" , "" ),
5756 request_limit = True ,
5857 rpm = RPM (env .get ("RPM" , 1000 )),
5958 tpm = TPM (env .get ("TPM" , 50000 )),
59+ tokenizer = tokenizer_instance ,
6060 )
6161
62- graph_gen .tokenizer_instance = Tokenizer (config .get ("tokenizer" , "cl100k_base" ))
62+ graph_gen = GraphGen (
63+ working_dir = working_dir ,
64+ tokenizer_instance = tokenizer_instance ,
65+ synthesizer_llm_client = synthesizer_llm_client ,
66+ trainee_llm_client = trainee_llm_client ,
67+ )
6368
6469 return graph_gen
6570
@@ -78,27 +83,32 @@ def sum_tokens(client):
7883 "chunk_size" : params .chunk_size ,
7984 "chunk_overlap" : params .chunk_overlap ,
8085 },
81- "output_data_type" : params .output_data_type ,
82- "output_data_format" : params .output_data_format ,
83- "tokenizer" : params .tokenizer ,
8486 "search" : {"enabled" : False },
85- "quiz_and_judge_strategy " : {
87+ "quiz_and_judge " : {
8688 "enabled" : params .if_trainee_model ,
8789 "quiz_samples" : params .quiz_samples ,
8890 },
89- "traverse_strategy" : {
90- "bidirectional" : params .bidirectional ,
91- "expand_method" : params .expand_method ,
92- "max_extra_edges" : params .max_extra_edges ,
93- "max_tokens" : params .max_tokens ,
94- "max_depth" : params .max_depth ,
95- "edge_sampling" : params .edge_sampling ,
96- "isolated_node_strategy" : params .isolated_node_strategy ,
97- "loss_strategy" : params .loss_strategy ,
91+ "partition" : {
92+ "method" : "ece" ,
93+ "method_params" : {
94+ "bidirectional" : params .bidirectional ,
95+ "expand_method" : params .expand_method ,
96+ "max_extra_edges" : params .max_extra_edges ,
97+ "max_tokens" : params .max_tokens ,
98+ "max_depth" : params .max_depth ,
99+ "edge_sampling" : params .edge_sampling ,
100+ "isolated_node_strategy" : params .isolated_node_strategy ,
101+ "loss_strategy" : params .loss_strategy ,
102+ },
103+ },
104+ "generate" : {
105+ "mode" : params .output_data_type ,
106+ "data_format" : params .output_data_format ,
98107 },
99108 }
100109
101110 env = {
111+ "TOKENIZER_MODEL" : params .tokenizer ,
102112 "SYNTHESIZER_BASE_URL" : params .synthesizer_url ,
103113 "SYNTHESIZER_MODEL" : params .synthesizer_model ,
104114 "TRAINEE_BASE_URL" : params .trainee_url ,
@@ -128,19 +138,18 @@ def sum_tokens(client):
128138
129139 try :
130140 # Process the data
131- graph_gen .insert ()
141+ graph_gen .insert (read_config = config [ "read" ], split_config = config [ "split" ] )
132142
133143 if config ["if_trainee_model" ]:
134- # Generate quiz
135- graph_gen .quiz ()
136-
137- # Judge statements
138- graph_gen .judge ()
144+ # Quiz and Judge
145+ graph_gen .quiz_and_judge (quiz_and_judge_config = config ["quiz_and_judge" ])
139146 else :
140- graph_gen . traverse_strategy . edge_sampling = "random"
147+ config [ "partition" ][ "method_params" ][ " edge_sampling" ] = "random"
141148
142- # Traverse graph
143- graph_gen .traverse ()
149+ graph_gen .generate (
150+ partition_config = config ["partition" ],
151+ generate_config = config ["generate" ],
152+ )
144153
145154 # Save output
146155 output_data = graph_gen .qa_storage .data
@@ -249,6 +258,9 @@ def sum_tokens(client):
249258 )
250259
251260 with gr .Accordion (label = _ ("Model Config" ), open = False ):
261+ tokenizer = gr .Textbox (
262+ label = "Tokenizer" , value = "cl100k_base" , interactive = True
263+ )
252264 synthesizer_url = gr .Textbox (
253265 label = "Synthesizer URL" ,
254266 value = "https://api.siliconflow.cn/v1" ,
@@ -300,9 +312,6 @@ def sum_tokens(client):
300312 step = 100 ,
301313 interactive = True ,
302314 )
303- tokenizer = gr .Textbox (
304- label = "Tokenizer" , value = "cl100k_base" , interactive = True
305- )
306315 output_data_type = gr .Radio (
307316 choices = ["atomic" , "multi_hop" , "aggregated" ],
308317 label = "Output Data Type" ,