@@ -160,7 +160,7 @@ def run_graphgen(
160160 return f"Error occurred: { str (e )} "
161161
162162# Create Gradio interface
163- with gr .Blocks (title = "GraphGen Demo" , theme = gr .themes .Citrus (), css = css ) as demo :
163+ with gr .Blocks (title = "GraphGen Demo" , theme = gr .themes .Base (), css = css ) as demo :
164164 # Header
165165 gr .Image (
166166 value = f"{ root_dir } /resources/images/logo.png" ,
@@ -212,32 +212,23 @@ def run_graphgen(
212212 "### [GraphGen](https://github.com/open-sciencelab/GraphGen) " + _ ("Intro" )
213213 )
214214
215+ # Model Configuration Column
216+ with gr .Accordion (label = _ ("Model Config" ), open = False ):
217+ base_url = gr .Textbox (label = "Base URL" , value = "https://api.siliconflow.cn/v1" ,
218+ info = _ ("Base URL Info" ), interactive = True )
219+ synthesizer_model = gr .Textbox (label = "Synthesizer Model" , value = "Qwen/Qwen2.5-72B-Instruct" ,
220+ info = _ ("Synthesizer Model Info" ), interactive = True )
221+ trainee_model = gr .Textbox (label = "Trainee Model" , value = "Qwen/Qwen2.5-7B-Instruct" ,
222+ info = _ ("Trainee Model Info" ), interactive = True )
223+
224+ with gr .Accordion (label = _ ("Generation Config" ), open = False ):
225+ qa_form = gr .Radio (choices = ["atomic" , "multi_hop" , "open" ], label = "QA Form" ,
226+ value = "multi_hop" , interactive = True )
227+ tokenizer = gr .Textbox (label = "Tokenizer" , value = "cl100k_base" )
228+ # web_search = gr.Checkbox(label="Enable Web Search", value=False)
229+ # quiz_samples = gr.Number(label="Quiz Samples", value=2, minimum=1)
215230
216- with gr .Row ():
217- # Model Configuration Column
218231 with gr .Column (scale = 1 ):
219- gr .Markdown ("### Model Configuration" )
220- synthesizer_model = gr .Textbox (label = "Synthesizer Model" , value = "" )
221- synthesizer_base_url = gr .Textbox (label = "Synthesizer Base URL" , value = "" )
222- synthesizer_api_key = gr .Textbox (label = "Synthesizer API Key" , type = "password" , value = "" )
223- trainee_model = gr .Textbox (label = "Trainee Model" , value = "" )
224- trainee_base_url = gr .Textbox (label = "Trainee Base URL" , value = "" )
225- trainee_api_key = gr .Textbox (label = "Trainee API Key" , type = "password" , value = "" )
226- test_connection_btn = gr .Button ("Test Connection" )
227-
228- # Input Configuration Column
229- with gr .Column (scale = 1 ):
230- gr .Markdown ("### Input Configuration" )
231- input_file = gr .Textbox (label = "Input File Path" , value = "resources/examples/raw_demo.jsonl" )
232- data_type = gr .Radio (choices = ["raw" , "chunked" ], label = "Data Type" , value = "raw" )
233- qa_form = gr .Radio (choices = ["atomic" , "multi_hop" , "open" ], label = "QA Form" , value = "multi_hop" )
234- tokenizer = gr .Textbox (label = "Tokenizer" , value = "cl100k_base" )
235- web_search = gr .Checkbox (label = "Enable Web Search" , value = False )
236- quiz_samples = gr .Number (label = "Quiz Samples" , value = 2 , minimum = 1 )
237-
238- # Traverse Strategy Column
239- with gr .Column (scale = 1 ):
240- gr .Markdown ("### Traverse Strategy" )
241232 expand_method = gr .Radio (choices = ["max_width" , "max_tokens" ], label = "Expand Method" , value = "max_tokens" )
242233 bidirectional = gr .Checkbox (label = "Bidirectional" , value = True )
243234 max_extra_edges = gr .Slider (minimum = 1 , maximum = 10 , value = 5 , label = "Max Extra Edges" , step = 1 )
@@ -248,36 +239,56 @@ def run_graphgen(
248239 isolated_node_strategy = gr .Radio (choices = ["add" , "ignore" ], label = "Isolated Node Strategy" , value = "ignore" )
249240 difficulty_level = gr .Radio (choices = ["easy" , "medium" , "hard" ], label = "Difficulty Level" , value = "medium" )
250241
251- # Submission and Output Rows
252- with gr .Row ():
253- submit_btn = gr .Button ("Run GraphGen" )
254- with gr .Row ():
255- output = gr .Textbox (label = "Output" )
242+ with gr .Row (equal_height = True ):
243+ with gr .Column (scale = 3 ):
244+ api_key = gr .Textbox (label = "SiliconFlow API Key" , type = "password" , value = "" )
245+ with gr .Column (scale = 1 ):
246+ test_connection_btn = gr .Button ("Test Connection" )
247+
248+ with gr .Blocks ():
249+ with gr .Row (equal_height = True ):
250+ with gr .Column (scale = 1 ):
251+ upload_btn = gr .File (
252+ label = "Upload File" ,
253+ file_count = "multiple" ,
254+ value = ["translation.json" , "translation.json" ],
255+ interactive = True ,
256+ )
257+ with gr .Column (scale = 1 ):
258+ download_file = gr .File (
259+ label = "Output" ,
260+ file_count = "single" ,
261+ interactive = False ,
262+ )
263+
264+ submit_btn = gr .Button ("Run GraphGen" )
256265
257266 # Test Connection
258267 test_connection_btn .click (
259268 test_api_connection ,
260- inputs = [synthesizer_base_url , synthesizer_api_key , synthesizer_model ],
261- outputs = output
262- ).then (
269+ inputs = [base_url , api_key , synthesizer_model ],
270+ outputs = []
271+ )
272+
273+ test_connection_btn .click (
263274 test_api_connection ,
264- inputs = [trainee_base_url , trainee_api_key , trainee_model ],
265- outputs = output
275+ inputs = [base_url , api_key , trainee_model ],
276+ outputs = []
266277 )
267278
268279 # Event Handling
269- submit_btn .click (
270- run_graphgen ,
271- inputs = [
272- input_file , data_type , qa_form , tokenizer , web_search ,
273- expand_method , bidirectional , max_extra_edges , max_tokens ,
274- max_depth , edge_sampling , isolated_node_strategy , difficulty_level ,
275- synthesizer_model , synthesizer_base_url , synthesizer_api_key ,
276- trainee_model , trainee_base_url , trainee_api_key ,
277- quiz_samples
278- ],
279- outputs = output
280- )
280+ # submit_btn.click(
281+ # run_graphgen,
282+ # inputs=[
283+ # input_file, data_type, qa_form, tokenizer, web_search,
284+ # expand_method, bidirectional, max_extra_edges, max_tokens,
285+ # max_depth, edge_sampling, isolated_node_strategy, difficulty_level,
286+ # synthesizer_model, synthesizer_base_url, synthesizer_api_key,
287+ # trainee_model, trainee_base_url, trainee_api_key,
288+ # quiz_samples
289+ # ],
290+ # outputs=output
291+ # )
281292
282293if __name__ == "__main__" :
283294 demo .launch ()
0 commit comments