Skip to content

Commit b612600

Browse files
feat(webui): add test_api func
1 parent d47e504 commit b612600

File tree

3 files changed

+88
-51
lines changed

3 files changed

+88
-51
lines changed

webui/app.py

Lines changed: 58 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def run_graphgen(
160160
return f"Error occurred: {str(e)}"
161161

162162
# Create Gradio interface
163-
with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Citrus(), css=css) as demo:
163+
with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Base(), css=css) as demo:
164164
# Header
165165
gr.Image(
166166
value=f"{root_dir}/resources/images/logo.png",
@@ -212,32 +212,23 @@ def run_graphgen(
212212
"### [GraphGen](https://github.com/open-sciencelab/GraphGen) " + _("Intro")
213213
)
214214

215+
# Model Configuration Column
216+
with gr.Accordion(label=_("Model Config"), open=False):
217+
base_url = gr.Textbox(label="Base URL", value="https://api.siliconflow.cn/v1",
218+
info=_("Base URL Info"), interactive=True)
219+
synthesizer_model = gr.Textbox(label="Synthesizer Model", value="Qwen/Qwen2.5-72B-Instruct",
220+
info=_("Synthesizer Model Info"), interactive=True)
221+
trainee_model = gr.Textbox(label="Trainee Model", value="Qwen/Qwen2.5-7B-Instruct",
222+
info=_("Trainee Model Info"), interactive=True)
223+
224+
with gr.Accordion(label=_("Generation Config"), open=False):
225+
qa_form = gr.Radio(choices=["atomic", "multi_hop", "open"], label="QA Form",
226+
value="multi_hop", interactive=True)
227+
tokenizer = gr.Textbox(label="Tokenizer", value="cl100k_base")
228+
# web_search = gr.Checkbox(label="Enable Web Search", value=False)
229+
# quiz_samples = gr.Number(label="Quiz Samples", value=2, minimum=1)
215230

216-
with gr.Row():
217-
# Model Configuration Column
218231
with gr.Column(scale=1):
219-
gr.Markdown("### Model Configuration")
220-
synthesizer_model = gr.Textbox(label="Synthesizer Model", value="")
221-
synthesizer_base_url = gr.Textbox(label="Synthesizer Base URL", value="")
222-
synthesizer_api_key = gr.Textbox(label="Synthesizer API Key", type="password", value="")
223-
trainee_model = gr.Textbox(label="Trainee Model", value="")
224-
trainee_base_url = gr.Textbox(label="Trainee Base URL", value="")
225-
trainee_api_key = gr.Textbox(label="Trainee API Key", type="password", value="")
226-
test_connection_btn = gr.Button("Test Connection")
227-
228-
# Input Configuration Column
229-
with gr.Column(scale=1):
230-
gr.Markdown("### Input Configuration")
231-
input_file = gr.Textbox(label="Input File Path", value="resources/examples/raw_demo.jsonl")
232-
data_type = gr.Radio(choices=["raw", "chunked"], label="Data Type", value="raw")
233-
qa_form = gr.Radio(choices=["atomic", "multi_hop", "open"], label="QA Form", value="multi_hop")
234-
tokenizer = gr.Textbox(label="Tokenizer", value="cl100k_base")
235-
web_search = gr.Checkbox(label="Enable Web Search", value=False)
236-
quiz_samples = gr.Number(label="Quiz Samples", value=2, minimum=1)
237-
238-
# Traverse Strategy Column
239-
with gr.Column(scale=1):
240-
gr.Markdown("### Traverse Strategy")
241232
expand_method = gr.Radio(choices=["max_width", "max_tokens"], label="Expand Method", value="max_tokens")
242233
bidirectional = gr.Checkbox(label="Bidirectional", value=True)
243234
max_extra_edges = gr.Slider(minimum=1, maximum=10, value=5, label="Max Extra Edges", step=1)
@@ -248,36 +239,56 @@ def run_graphgen(
248239
isolated_node_strategy = gr.Radio(choices=["add", "ignore"], label="Isolated Node Strategy", value="ignore")
249240
difficulty_level = gr.Radio(choices=["easy", "medium", "hard"], label="Difficulty Level", value="medium")
250241

251-
# Submission and Output Rows
252-
with gr.Row():
253-
submit_btn = gr.Button("Run GraphGen")
254-
with gr.Row():
255-
output = gr.Textbox(label="Output")
242+
with gr.Row(equal_height=True):
243+
with gr.Column(scale=3):
244+
api_key = gr.Textbox(label="SiliconFlow API Key", type="password", value="")
245+
with gr.Column(scale=1):
246+
test_connection_btn = gr.Button("Test Connection")
247+
248+
with gr.Blocks():
249+
with gr.Row(equal_height=True):
250+
with gr.Column(scale=1):
251+
upload_btn = gr.File(
252+
label="Upload File",
253+
file_count="multiple",
254+
value=["translation.json", "translation.json"],
255+
interactive=True,
256+
)
257+
with gr.Column(scale=1):
258+
download_file = gr.File(
259+
label="Output",
260+
file_count="single",
261+
interactive=False,
262+
)
263+
264+
submit_btn = gr.Button("Run GraphGen")
256265

257266
# Test Connection
258267
test_connection_btn.click(
259268
test_api_connection,
260-
inputs=[synthesizer_base_url, synthesizer_api_key, synthesizer_model],
261-
outputs=output
262-
).then(
269+
inputs=[base_url, api_key, synthesizer_model],
270+
outputs=[]
271+
)
272+
273+
test_connection_btn.click(
263274
test_api_connection,
264-
inputs=[trainee_base_url, trainee_api_key, trainee_model],
265-
outputs=output
275+
inputs=[base_url, api_key, trainee_model],
276+
outputs=[]
266277
)
267278

268279
# Event Handling
269-
submit_btn.click(
270-
run_graphgen,
271-
inputs=[
272-
input_file, data_type, qa_form, tokenizer, web_search,
273-
expand_method, bidirectional, max_extra_edges, max_tokens,
274-
max_depth, edge_sampling, isolated_node_strategy, difficulty_level,
275-
synthesizer_model, synthesizer_base_url, synthesizer_api_key,
276-
trainee_model, trainee_base_url, trainee_api_key,
277-
quiz_samples
278-
],
279-
outputs=output
280-
)
280+
# submit_btn.click(
281+
# run_graphgen,
282+
# inputs=[
283+
# input_file, data_type, qa_form, tokenizer, web_search,
284+
# expand_method, bidirectional, max_extra_edges, max_tokens,
285+
# max_depth, edge_sampling, isolated_node_strategy, difficulty_level,
286+
# synthesizer_model, synthesizer_base_url, synthesizer_api_key,
287+
# trainee_model, trainee_base_url, trainee_api_key,
288+
# quiz_samples
289+
# ],
290+
# outputs=output
291+
# )
281292

282293
if __name__ == "__main__":
283294
demo.launch()

webui/test_api.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from openai import OpenAI
2+
import gradio as gr
3+
4+
def test_api_connection(api_base, api_key, model_name):
5+
client = OpenAI(api_key=api_key, base_url=api_base)
6+
try:
7+
response = client.chat.completions.create(
8+
model=model_name,
9+
messages=[{"role": "user", "content": "test"}],
10+
max_tokens=1
11+
)
12+
if not response.choices or not response.choices[0].message:
13+
gr.Error(f"{model_name}: Invalid response from API")
14+
gr.Success(f"{model_name}: API connection successful")
15+
except Exception as e:
16+
gr.Error(f"{model_name}: API connection failed: {str(e)}")

webui/translation.json

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
11
{
22
"en": {
3-
"Title": "Revolutionizing Data Generation for LLMs",
4-
"Intro": "is a framework for synthetic data generation guided by knowledge graphs, designed to tackle challenges for knowledge-intensive QA generation."
3+
"Title": "Revolutionizing Training Data Generation for LLMs",
4+
"Intro": "is a framework for synthetic data generation guided by knowledge graphs, designed to tackle challenges for knowledge-intensive QA generation.",
5+
"Base URL Info": "Base URL for the API, use SiliconFlow as default",
6+
"Synthesizer Model Info": "Synthesizer Model Info",
7+
"Trainee Model Info": "Model for training",
8+
"Model Config": "Model Configuration",
9+
"Generation Config": "Generation Config"
510
},
611
"zh": {
7-
"Title": "革新LLM数据生成",
8-
"Intro": "是一个基于知识图谱的合成数据生成框架,旨在解决知识密集型问答生成的挑战。"
12+
"Title": "革新LLM训练数据生成",
13+
"Intro": "是一个基于知识图谱的合成数据生成框架,旨在解决知识密集型问答生成的挑战。",
14+
"Base URL Info": "调用模型API的URL,默认使用硅基流动",
15+
"Synthesizer Model Info": "Synthesizer Model Info",
16+
"Trainee Model Info": "用于训练的模型",
17+
"Model Config": "模型配置",
18+
"Generation Config": "生成配置"
919
}
1020
}

0 commit comments

Comments
 (0)