Skip to content

Commit 31614cb

Browse files
feat(webui): support i18n
1 parent 4af147d commit 31614cb

File tree

3 files changed

+110
-58
lines changed

3 files changed

+110
-58
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@ torch
1515
plotly
1616
pandas
1717
gradio
18+
gradio-i18n
1819
kaleido
1920
pyyaml

webui/app.py

Lines changed: 105 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,25 @@
22
import os
33
import gradio as gr
44
import yaml
5+
6+
import sys
7+
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
8+
sys.path.append(root_dir)
9+
10+
from gradio_i18n import Translate, gettext as _
11+
512
from graphgen.graphgen import GraphGen
613
from models import OpenAIModel, Tokenizer, TraverseStrategy
7-
14+
from test_api import test_api_connection
815

916
def load_config() -> dict:
1017
with open("config.yaml", "r", encoding='utf-8') as f:
1118
return yaml.safe_load(f)
1219

13-
1420
def save_config(config: dict):
1521
with open("config.yaml", "w", encoding='utf-8') as f:
1622
yaml.dump(config, f)
1723

18-
1924
def load_env() -> dict:
2025
env = {}
2126
if os.path.exists(".env"):
@@ -26,13 +31,11 @@ def load_env() -> dict:
2631
env[key] = value
2732
return env
2833

29-
3034
def save_env(env: dict):
3135
with open(".env", "w", encoding='utf-8') as f:
3236
for key, value in env.items():
3337
f.write(f"{key}={value}\n")
3438

35-
3639
def init_graph_gen(config: dict, env: dict) -> GraphGen:
3740
graph_gen = GraphGen()
3841

@@ -160,62 +163,106 @@ def run_graphgen(
160163
except Exception as e: # pylint: disable=broad-except
161164
return f"Error occurred: {str(e)}"
162165

166+
config = load_env()
163167

164168
# Create Gradio interface
165-
with gr.Blocks(title="GraphGen Configuration") as iface:
166-
with gr.Row():
167-
# Input Configuration Column
168-
with gr.Column(scale=1):
169-
gr.Markdown("### Input Configuration")
170-
input_file = gr.Textbox(label="Input File Path", value="resources/examples/raw_demo.jsonl")
171-
data_type = gr.Radio(choices=["raw", "chunked"], label="Data Type", value="raw")
172-
qa_form = gr.Radio(choices=["atomic", "multi_hop", "open"], label="QA Form", value="multi_hop")
173-
tokenizer = gr.Textbox(label="Tokenizer", value="cl100k_base")
174-
web_search = gr.Checkbox(label="Enable Web Search", value=False)
175-
quiz_samples = gr.Number(label="Quiz Samples", value=2, minimum=1)
176-
177-
# Traverse Strategy Column
178-
with gr.Column(scale=1):
179-
gr.Markdown("### Traverse Strategy")
180-
expand_method = gr.Radio(choices=["max_width", "max_tokens"], label="Expand Method", value="max_tokens")
181-
bidirectional = gr.Checkbox(label="Bidirectional", value=True)
182-
max_extra_edges = gr.Slider(minimum=1, maximum=10, value=5, label="Max Extra Edges", step=1)
183-
max_tokens = gr.Slider(minimum=64, maximum=1024, value=256, label="Max Tokens", step=64)
184-
max_depth = gr.Slider(minimum=1, maximum=5, value=2, label="Max Depth", step=1)
185-
edge_sampling = gr.Radio(choices=["max_loss", "min_loss", "random"], label="Edge Sampling",
186-
value="max_loss")
187-
isolated_node_strategy = gr.Radio(choices=["add", "ignore"], label="Isolated Node Strategy", value="ignore")
188-
difficulty_level = gr.Radio(choices=["easy", "medium", "hard"], label="Difficulty Level", value="medium")
189-
190-
# Model Configuration Column
191-
with gr.Column(scale=1):
192-
gr.Markdown("### Model Configuration")
193-
synthesizer_model = gr.Textbox(label="Synthesizer Model")
194-
synthesizer_base_url = gr.Textbox(label="Synthesizer Base URL")
195-
synthesizer_api_key = gr.Textbox(label="Synthesizer API Key", type="password")
196-
trainee_model = gr.Textbox(label="Trainee Model")
197-
trainee_base_url = gr.Textbox(label="Trainee Base URL")
198-
trainee_api_key = gr.Textbox(label="Trainee API Key", type="password")
199-
200-
# Submission and Output Rows
201-
with gr.Row():
202-
submit_btn = gr.Button("Run GraphGen")
203-
with gr.Row():
204-
output = gr.Textbox(label="Output")
205-
206-
# Event Handling
207-
submit_btn.click(
208-
run_graphgen,
209-
inputs=[
210-
input_file, data_type, qa_form, tokenizer, web_search,
211-
expand_method, bidirectional, max_extra_edges, max_tokens,
212-
max_depth, edge_sampling, isolated_node_strategy, difficulty_level,
213-
synthesizer_model, synthesizer_base_url, synthesizer_api_key,
214-
trainee_model, trainee_base_url, trainee_api_key,
215-
quiz_samples
169+
with gr.Blocks(title="GraphGen Demo") as demo:
170+
lang = gr.Radio(
171+
choices=[
172+
("English", "en"),
173+
("简体中文", "zh"),
216174
],
217-
outputs=output
175+
value="en",
176+
label=_("Language"),
177+
render=False,
218178
)
179+
with Translate(
180+
"translation.yaml",
181+
lang,
182+
placeholder_langs=["en", "zh"],
183+
persistant=False, # True to save the language setting in the browser. Requires gradio >= 5.6.0
184+
):
185+
lang.render()
186+
# Header
187+
gr.Image(
188+
value=f"{root_dir}/resources/images/logo.png",
189+
label="GraphGen Banner",
190+
elem_id="banner",
191+
show_label=False,
192+
interactive=False,
193+
)
194+
gr.Markdown(
195+
"""
196+
This is a demo for the [GraphGen](https://github.com/open-sciencelab/GraphGen) project.
197+
GraphGen is a framework for synthetic data generation guided by knowledge graphs.
198+
"""
199+
)
200+
with gr.Row():
201+
# Model Configuration Column
202+
with gr.Column(scale=1):
203+
gr.Markdown("### Model Configuration")
204+
synthesizer_model = gr.Textbox(label="Synthesizer Model", value=config.get("SYNTHESIZER_MODEL", ""))
205+
synthesizer_base_url = gr.Textbox(label="Synthesizer Base URL", value=config.get("SYNTHESIZER_BASE_URL", ""))
206+
synthesizer_api_key = gr.Textbox(label="Synthesizer API Key", type="password", value=config.get("SYNTHESIZER_API_KEY", ""))
207+
trainee_model = gr.Textbox(label="Trainee Model", value=config.get("TRAINEE_MODEL", ""))
208+
trainee_base_url = gr.Textbox(label="Trainee Base URL", value=config.get("TRAINEE_BASE_URL", ""))
209+
trainee_api_key = gr.Textbox(label="Trainee API Key", type="password", value=config.get("TRAINEE_API_KEY", ""))
210+
test_connection_btn = gr.Button("Test Connection", variant="primary")
211+
gr.Button("Save Config", variant="primary")
212+
213+
# Input Configuration Column
214+
with gr.Column(scale=1):
215+
gr.Markdown("### Input Configuration")
216+
input_file = gr.Textbox(label="Input File Path", value="resources/examples/raw_demo.jsonl")
217+
data_type = gr.Radio(choices=["raw", "chunked"], label="Data Type", value="raw")
218+
qa_form = gr.Radio(choices=["atomic", "multi_hop", "open"], label="QA Form", value="multi_hop")
219+
tokenizer = gr.Textbox(label="Tokenizer", value="cl100k_base")
220+
web_search = gr.Checkbox(label="Enable Web Search", value=False)
221+
quiz_samples = gr.Number(label="Quiz Samples", value=2, minimum=1)
222+
223+
# Traverse Strategy Column
224+
with gr.Column(scale=1):
225+
gr.Markdown("### Traverse Strategy")
226+
expand_method = gr.Radio(choices=["max_width", "max_tokens"], label="Expand Method", value="max_tokens")
227+
bidirectional = gr.Checkbox(label="Bidirectional", value=True)
228+
max_extra_edges = gr.Slider(minimum=1, maximum=10, value=5, label="Max Extra Edges", step=1)
229+
max_tokens = gr.Slider(minimum=64, maximum=1024, value=256, label="Max Tokens", step=64)
230+
max_depth = gr.Slider(minimum=1, maximum=5, value=2, label="Max Depth", step=1)
231+
edge_sampling = gr.Radio(choices=["max_loss", "min_loss", "random"], label="Edge Sampling",
232+
value="max_loss")
233+
isolated_node_strategy = gr.Radio(choices=["add", "ignore"], label="Isolated Node Strategy", value="ignore")
234+
difficulty_level = gr.Radio(choices=["easy", "medium", "hard"], label="Difficulty Level", value="medium")
235+
236+
# Submission and Output Rows
237+
with gr.Row():
238+
submit_btn = gr.Button("Run GraphGen")
239+
with gr.Row():
240+
output = gr.Textbox(label="Output")
241+
242+
# Test Connection
243+
test_connection_btn.click(
244+
test_api_connection,
245+
inputs=[synthesizer_base_url, synthesizer_api_key, synthesizer_model],
246+
outputs=output
247+
).then(
248+
test_api_connection,
249+
inputs=[trainee_base_url, trainee_api_key, trainee_model],
250+
outputs=output
251+
)
252+
253+
# Event Handling
254+
submit_btn.click(
255+
run_graphgen,
256+
inputs=[
257+
input_file, data_type, qa_form, tokenizer, web_search,
258+
expand_method, bidirectional, max_extra_edges, max_tokens,
259+
max_depth, edge_sampling, isolated_node_strategy, difficulty_level,
260+
synthesizer_model, synthesizer_base_url, synthesizer_api_key,
261+
trainee_model, trainee_base_url, trainee_api_key,
262+
quiz_samples
263+
],
264+
outputs=output
265+
)
219266

220267
if __name__ == "__main__":
221-
iface.launch()
268+
demo.launch()

webui/translation.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
en:
2+
Language: Language
3+
zh:
4+
Language: 语言

0 commit comments

Comments
 (0)