22import os
33import gradio as gr
44import yaml
5+
6+ import sys
7+ root_dir = os .path .dirname (os .path .dirname (os .path .abspath (__file__ )))
8+ sys .path .append (root_dir )
9+
10+ from gradio_i18n import Translate , gettext as _
11+
512from graphgen .graphgen import GraphGen
613from models import OpenAIModel , Tokenizer , TraverseStrategy
7-
14+ from test_api import test_api_connection
815
916def load_config () -> dict :
1017 with open ("config.yaml" , "r" , encoding = 'utf-8' ) as f :
1118 return yaml .safe_load (f )
1219
13-
1420def save_config (config : dict ):
1521 with open ("config.yaml" , "w" , encoding = 'utf-8' ) as f :
1622 yaml .dump (config , f )
1723
18-
1924def load_env () -> dict :
2025 env = {}
2126 if os .path .exists (".env" ):
@@ -26,13 +31,11 @@ def load_env() -> dict:
2631 env [key ] = value
2732 return env
2833
29-
3034def save_env (env : dict ):
3135 with open (".env" , "w" , encoding = 'utf-8' ) as f :
3236 for key , value in env .items ():
3337 f .write (f"{ key } ={ value } \n " )
3438
35-
3639def init_graph_gen (config : dict , env : dict ) -> GraphGen :
3740 graph_gen = GraphGen ()
3841
@@ -160,62 +163,106 @@ def run_graphgen(
160163 except Exception as e : # pylint: disable=broad-except
161164 return f"Error occurred: { str (e )} "
162165
166+ config = load_env ()
163167
164168# Create Gradio interface
165- with gr .Blocks (title = "GraphGen Configuration" ) as iface :
166- with gr .Row ():
167- # Input Configuration Column
168- with gr .Column (scale = 1 ):
169- gr .Markdown ("### Input Configuration" )
170- input_file = gr .Textbox (label = "Input File Path" , value = "resources/examples/raw_demo.jsonl" )
171- data_type = gr .Radio (choices = ["raw" , "chunked" ], label = "Data Type" , value = "raw" )
172- qa_form = gr .Radio (choices = ["atomic" , "multi_hop" , "open" ], label = "QA Form" , value = "multi_hop" )
173- tokenizer = gr .Textbox (label = "Tokenizer" , value = "cl100k_base" )
174- web_search = gr .Checkbox (label = "Enable Web Search" , value = False )
175- quiz_samples = gr .Number (label = "Quiz Samples" , value = 2 , minimum = 1 )
176-
177- # Traverse Strategy Column
178- with gr .Column (scale = 1 ):
179- gr .Markdown ("### Traverse Strategy" )
180- expand_method = gr .Radio (choices = ["max_width" , "max_tokens" ], label = "Expand Method" , value = "max_tokens" )
181- bidirectional = gr .Checkbox (label = "Bidirectional" , value = True )
182- max_extra_edges = gr .Slider (minimum = 1 , maximum = 10 , value = 5 , label = "Max Extra Edges" , step = 1 )
183- max_tokens = gr .Slider (minimum = 64 , maximum = 1024 , value = 256 , label = "Max Tokens" , step = 64 )
184- max_depth = gr .Slider (minimum = 1 , maximum = 5 , value = 2 , label = "Max Depth" , step = 1 )
185- edge_sampling = gr .Radio (choices = ["max_loss" , "min_loss" , "random" ], label = "Edge Sampling" ,
186- value = "max_loss" )
187- isolated_node_strategy = gr .Radio (choices = ["add" , "ignore" ], label = "Isolated Node Strategy" , value = "ignore" )
188- difficulty_level = gr .Radio (choices = ["easy" , "medium" , "hard" ], label = "Difficulty Level" , value = "medium" )
189-
190- # Model Configuration Column
191- with gr .Column (scale = 1 ):
192- gr .Markdown ("### Model Configuration" )
193- synthesizer_model = gr .Textbox (label = "Synthesizer Model" )
194- synthesizer_base_url = gr .Textbox (label = "Synthesizer Base URL" )
195- synthesizer_api_key = gr .Textbox (label = "Synthesizer API Key" , type = "password" )
196- trainee_model = gr .Textbox (label = "Trainee Model" )
197- trainee_base_url = gr .Textbox (label = "Trainee Base URL" )
198- trainee_api_key = gr .Textbox (label = "Trainee API Key" , type = "password" )
199-
200- # Submission and Output Rows
201- with gr .Row ():
202- submit_btn = gr .Button ("Run GraphGen" )
203- with gr .Row ():
204- output = gr .Textbox (label = "Output" )
205-
206- # Event Handling
207- submit_btn .click (
208- run_graphgen ,
209- inputs = [
210- input_file , data_type , qa_form , tokenizer , web_search ,
211- expand_method , bidirectional , max_extra_edges , max_tokens ,
212- max_depth , edge_sampling , isolated_node_strategy , difficulty_level ,
213- synthesizer_model , synthesizer_base_url , synthesizer_api_key ,
214- trainee_model , trainee_base_url , trainee_api_key ,
215- quiz_samples
169+ with gr .Blocks (title = "GraphGen Demo" ) as demo :
170+ lang = gr .Radio (
171+ choices = [
172+ ("English" , "en" ),
173+ ("简体中文" , "zh" ),
216174 ],
217- outputs = output
175+ value = "en" ,
176+ label = _ ("Language" ),
177+ render = False ,
218178 )
179+ with Translate (
180+ "translation.yaml" ,
181+ lang ,
182+ placeholder_langs = ["en" , "zh" ],
183+ persistant = False , # True to save the language setting in the browser. Requires gradio >= 5.6.0
184+ ):
185+ lang .render ()
186+ # Header
187+ gr .Image (
188+ value = f"{ root_dir } /resources/images/logo.png" ,
189+ label = "GraphGen Banner" ,
190+ elem_id = "banner" ,
191+ show_label = False ,
192+ interactive = False ,
193+ )
194+ gr .Markdown (
195+ """
196+ This is a demo for the [GraphGen](https://github.com/open-sciencelab/GraphGen) project.
197+ GraphGen is a framework for synthetic data generation guided by knowledge graphs.
198+ """
199+ )
200+ with gr .Row ():
201+ # Model Configuration Column
202+ with gr .Column (scale = 1 ):
203+ gr .Markdown ("### Model Configuration" )
204+ synthesizer_model = gr .Textbox (label = "Synthesizer Model" , value = config .get ("SYNTHESIZER_MODEL" , "" ))
205+ synthesizer_base_url = gr .Textbox (label = "Synthesizer Base URL" , value = config .get ("SYNTHESIZER_BASE_URL" , "" ))
206+ synthesizer_api_key = gr .Textbox (label = "Synthesizer API Key" , type = "password" , value = config .get ("SYNTHESIZER_API_KEY" , "" ))
207+ trainee_model = gr .Textbox (label = "Trainee Model" , value = config .get ("TRAINEE_MODEL" , "" ))
208+ trainee_base_url = gr .Textbox (label = "Trainee Base URL" , value = config .get ("TRAINEE_BASE_URL" , "" ))
209+ trainee_api_key = gr .Textbox (label = "Trainee API Key" , type = "password" , value = config .get ("TRAINEE_API_KEY" , "" ))
210+ test_connection_btn = gr .Button ("Test Connection" , variant = "primary" )
211+ gr .Button ("Save Config" , variant = "primary" )
212+
213+ # Input Configuration Column
214+ with gr .Column (scale = 1 ):
215+ gr .Markdown ("### Input Configuration" )
216+ input_file = gr .Textbox (label = "Input File Path" , value = "resources/examples/raw_demo.jsonl" )
217+ data_type = gr .Radio (choices = ["raw" , "chunked" ], label = "Data Type" , value = "raw" )
218+ qa_form = gr .Radio (choices = ["atomic" , "multi_hop" , "open" ], label = "QA Form" , value = "multi_hop" )
219+ tokenizer = gr .Textbox (label = "Tokenizer" , value = "cl100k_base" )
220+ web_search = gr .Checkbox (label = "Enable Web Search" , value = False )
221+ quiz_samples = gr .Number (label = "Quiz Samples" , value = 2 , minimum = 1 )
222+
223+ # Traverse Strategy Column
224+ with gr .Column (scale = 1 ):
225+ gr .Markdown ("### Traverse Strategy" )
226+ expand_method = gr .Radio (choices = ["max_width" , "max_tokens" ], label = "Expand Method" , value = "max_tokens" )
227+ bidirectional = gr .Checkbox (label = "Bidirectional" , value = True )
228+ max_extra_edges = gr .Slider (minimum = 1 , maximum = 10 , value = 5 , label = "Max Extra Edges" , step = 1 )
229+ max_tokens = gr .Slider (minimum = 64 , maximum = 1024 , value = 256 , label = "Max Tokens" , step = 64 )
230+ max_depth = gr .Slider (minimum = 1 , maximum = 5 , value = 2 , label = "Max Depth" , step = 1 )
231+ edge_sampling = gr .Radio (choices = ["max_loss" , "min_loss" , "random" ], label = "Edge Sampling" ,
232+ value = "max_loss" )
233+ isolated_node_strategy = gr .Radio (choices = ["add" , "ignore" ], label = "Isolated Node Strategy" , value = "ignore" )
234+ difficulty_level = gr .Radio (choices = ["easy" , "medium" , "hard" ], label = "Difficulty Level" , value = "medium" )
235+
236+ # Submission and Output Rows
237+ with gr .Row ():
238+ submit_btn = gr .Button ("Run GraphGen" )
239+ with gr .Row ():
240+ output = gr .Textbox (label = "Output" )
241+
242+ # Test Connection
243+ test_connection_btn .click (
244+ test_api_connection ,
245+ inputs = [synthesizer_base_url , synthesizer_api_key , synthesizer_model ],
246+ outputs = output
247+ ).then (
248+ test_api_connection ,
249+ inputs = [trainee_base_url , trainee_api_key , trainee_model ],
250+ outputs = output
251+ )
252+
253+ # Event Handling
254+ submit_btn .click (
255+ run_graphgen ,
256+ inputs = [
257+ input_file , data_type , qa_form , tokenizer , web_search ,
258+ expand_method , bidirectional , max_extra_edges , max_tokens ,
259+ max_depth , edge_sampling , isolated_node_strategy , difficulty_level ,
260+ synthesizer_model , synthesizer_base_url , synthesizer_api_key ,
261+ trainee_model , trainee_base_url , trainee_api_key ,
262+ quiz_samples
263+ ],
264+ outputs = output
265+ )
219266
220267if __name__ == "__main__" :
221- iface .launch ()
268+ demo .launch ()
0 commit comments