Skip to content

Commit 4297f93

Browse files
committed
兼容openai api
1 parent 7cffe55 commit 4297f93

File tree

2 files changed

+421
-154
lines changed

2 files changed

+421
-154
lines changed

scripts/gradio_demo.py

Lines changed: 100 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,26 @@
1+
import subprocess
12
import time
23
import gradio as gr
4+
from openai import OpenAI
35
import requests
46
import json
7+
import re
58

69
# Base URL of your API server; adjust host and port as needed
7-
API_URL = "http://10.126.33.142:8000"
10+
API_URL = "http://0.0.0.0:8000/v1"
11+
MODEL = "AXERA-TECH/Qwen3-1.7B"
12+
13+
def get_all_local_ips():
    """Return every non-loopback IPv4 address reported by the `ip a` tool.

    Shells out to the Linux ``ip`` utility (argument-list form, no shell)
    and scrapes ``inet x.x.x.x`` entries from its stdout.

    Returns:
        list[str]: IPv4 addresses with loopback (127.x.x.x) filtered out.
        Empty list when the ``ip`` command is unavailable or fails, so the
        demo can still start on non-Linux hosts.
    """
    try:
        result = subprocess.run(['ip', 'a'], capture_output=True, text=True)
    except OSError:
        # `ip` is Linux-only; without it, report no addresses instead of
        # crashing the whole demo at startup.
        return []

    # Match all IPv4 addresses ...
    ips = re.findall(r'inet (\d+\.\d+\.\d+\.\d+)', result.stdout)

    # ... then drop loopback entries.
    return [ip for ip in ips if not ip.startswith('127.')]
824

925

1026
def reset_chat(system_prompt):
@@ -17,7 +33,7 @@ def reset_chat(system_prompt):
1733
if system_prompt:
1834
payload["system_prompt"] = system_prompt
1935
try:
20-
response = requests.post(f"{API_URL}/api/reset", json=payload)
36+
response = requests.post(f"{API_URL}/reset", json=payload)
2137
response.raise_for_status()
2238
except Exception as e:
2339
# Return error in chat if reset fails
@@ -26,68 +42,77 @@ def reset_chat(system_prompt):
2642
return [], ""
2743

2844

29-
def stream_generate(history, message, temperature, repetition_penalty, top_p, top_k):
30-
"""
31-
Sends the user message and sampling parameters to /api/generate.
32-
Streams the response chunks and updates the last bot message in history.
33-
Clears input after sending. On error, shows error in chat.
34-
"""
35-
history = history + [(message, "")]
36-
yield history, ""
37-
payload = {
38-
"prompt": message,
39-
"temperature": temperature,
40-
"repetition_penalty": repetition_penalty,
41-
"top-p": top_p,
42-
"top-k": top_k
43-
}
45+
def build_messages(prompt: str):
    """Wrap *prompt* into a single OpenAI-style user message dict.

    Non-blank prompts are stripped before being embedded; a blank or
    whitespace-only prompt falls back to the original string so the
    ``content`` list is never empty.
    """
    stripped = prompt.strip() if prompt else ""
    if stripped:
        parts = [{"type": "text", "text": stripped}]
    else:
        # Preserve the raw (possibly empty) prompt in the fallback bubble.
        parts = [{"type": "text", "text": prompt or ""}]
    return {"role": "user", "content": parts}
51+
52+
# ---------- Gradio callback (single-turn, stream) ----------
53+
def run_single_turn(prompt, chatbot_state):
    """Gradio callback: stream one chat turn through the OpenAI-compatible API.

    Appends a ``(user, assistant)`` bubble to *chatbot_state* and yields the
    updated state after every streamed chunk so the Chatbot widget renders
    incrementally.

    Args:
        prompt: Raw user text from the input box (may be None/empty).
        chatbot_state: Accumulated list of ``(user, assistant)`` tuples.

    Yields:
        ``(chatbot_state, chatbot_state)`` — once for the Chatbot component,
        once for the ``gr.State`` holder.
    """
    try:
        user_md = (prompt or "").strip()

        # Render the user bubble first; the assistant half streams in later.
        chatbot_state.append((user_md or "(空提示)", ""))
        yield chatbot_state, chatbot_state

        # Call the backend with streaming enabled.
        client = OpenAI(api_key="not-needed", base_url=API_URL.strip())
        # NOTE: chat.completions expects a *list* of message dicts; the
        # previous code passed the single dict from build_messages directly.
        stream = client.chat.completions.create(
            model=MODEL.strip(),
            messages=[build_messages(prompt=prompt or "")],
            stream=True,
        )

        bot_chunks = []
        # Update the assistant bubble chunk by chunk (rendered as Markdown).
        for ev in stream:
            delta = getattr(ev.choices[0], "delta", None)
            if not (delta and getattr(delta, "content", None)):
                continue
            # Chain both substitutions so a chunk that carries *both* tags
            # keeps the first replacement (the old code restarted from the
            # raw content and silently dropped it).
            ctx = delta.content.replace("<think>", "【思考中】").replace("</think>", "【思考结束】")
            bot_chunks.append(ctx)
            chatbot_state[-1] = (chatbot_state[-1][0], "".join(bot_chunks))
            yield chatbot_state, chatbot_state

        # Final flush — never leave the assistant bubble blank.
        chatbot_state[-1] = (chatbot_state[-1][0], "".join(bot_chunks) if bot_chunks else "(empty response)")
        yield chatbot_state, chatbot_state

    except Exception as e:
        # Surface the failure inside the chat instead of crashing the callback.
        chatbot_state.append((
            chatbot_state[-1][0] if chatbot_state else "(request)",
            f"**Error:** {e}"
        ))
        yield chatbot_state, chatbot_state
107+
108+
70109

71110
def stop_generate():
    """Send a best-effort stop request to the backend; failures are only logged."""
    stop_url = f"{API_URL}/stop"
    try:
        requests.get(stop_url)
    except Exception as err:
        # The demo keeps running even when the stop endpoint is unreachable.
        print(err)
76115

77-
# Build the Gradio interface optimized for PC with spacious layout
78-
# custom_css = """
79-
# .gradio-container {
80-
# max-width: 1400px;
81-
# margin: auto;
82-
# padding: 20px;
83-
# }
84-
# .gradio-container > * {
85-
# margin-bottom: 20px;
86-
# }
87-
# #chatbox .overflow-y-auto {
88-
# height: 600px !important;
89-
# }
90-
# """
91116

92117
# Build the Gradio interface优化布局
93118
with gr.Blocks(theme=gr.themes.Soft(font="Consolas"), fill_width=True) as demo:
@@ -111,26 +136,32 @@ def stop_generate():
111136
repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, step=0.01, value=1.0, label="Repetition Penalty")
112137
top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.9, label="Top-p Sampling")
113138
top_k = gr.Slider(minimum=0, maximum=100, step=1, value=40, label="Top-k Sampling")
114-
115-
# Wire up events: reset clears chat and input
116-
reset_button.click(fn=reset_chat, inputs=system_prompt, outputs=[chatbot, user_input])
117-
# send streams chat and clears input
139+
140+
141+
chat_state = gr.State([])
142+
143+
reset_button.click(
144+
fn=reset_chat,
145+
inputs=system_prompt,
146+
outputs=[chatbot, user_input],
147+
).then(
148+
lambda: [],
149+
inputs=None,
150+
outputs=chat_state
151+
)
152+
118153
send_button.click(
119-
fn=stream_generate,
120-
inputs=[chatbot, user_input, temperature, repetition_penalty, top_p, top_k],
121-
outputs=[chatbot, user_input]
154+
fn=run_single_turn,
155+
inputs=[user_input, chat_state],
156+
outputs=[chatbot, chat_state],
157+
show_progress=True,
158+
queue=True,
122159
)
123160

124161
stop_button.click(
125162
fn=stop_generate
126163
)
127164

128-
# allow Enter key to send
129-
user_input.submit(
130-
fn=stream_generate,
131-
inputs=[chatbot, user_input, temperature, repetition_penalty, top_p, top_k],
132-
outputs=[chatbot, user_input]
133-
)
134165

135166
if __name__ == "__main__":
136167
demo.launch(server_name="0.0.0.0", server_port=7860) # adjust as needed

0 commit comments

Comments (0)