
Commit b27b6fd

Add a simple gradio interface, make life easier
1 parent c4e0c63 commit b27b6fd

File tree

1 file changed (+59, -0 lines)


simple_gradio_interface.py

Lines changed: 59 additions & 0 deletions
import gradio as gr
import time
import torch
from medusa.model.medusa_model import MedusaModel
from fastchat.model.model_adapter import get_conversation_template

# Global variable to store the chat history
chat_history = ""


def medusa_chat_interface(user_input):
    global model, tokenizer, conv, chat_history

    # Add the user's input to the chat history
    chat_history += "\nYou: " + user_input

    # Process the user input and get the model's response
    conv.append_message(conv.roles[0], user_input)
    conv.append_message(conv.roles[1], '')  # Placeholder for the Medusa response
    prompt = conv.get_prompt()
    print(prompt)

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.base_model.device)

    outputs = model.medusa_generate(input_ids, temperature=0.7, max_steps=512)
    response = ""
    for output in outputs:
        response = output['text']
        # Stream the current partial response to the output box
        yield response, chat_history
        time.sleep(0.01)

    # Append Medusa's complete response to the chat history once streaming finishes
    chat_history += "\nMedusa: " + response.strip()

    return response, chat_history


if __name__ == "__main__":
    MODEL_PATH = "FasterDecoding/medusa-vicuna-7b-v1.3"
    model = MedusaModel.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="auto"
    )
    tokenizer = model.get_tokenizer()
    conv = get_conversation_template("vicuna")

    interface = gr.Interface(
        medusa_chat_interface,
        gr.components.Textbox(placeholder="Ask Medusa..."),
        [gr.components.Textbox(label="Medusa's Response", type="text"),
         gr.components.Textbox(label="Chat History", type="text")],
        live=False,
        description="Chat with Medusa",
        title="Medusa Chatbox"
    )
    interface.queue().launch()
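
Note: the UI streams because medusa_chat_interface is a generator; each yield pushes the current partial response to the output box, and interface.queue() lets Gradio deliver those incremental updates. Below is a minimal sketch of the same streaming pattern with a dummy generator, useful for checking the Gradio wiring without downloading the Medusa weights. It assumes only gradio is installed; echo_stream is a hypothetical stand-in for medusa_chat_interface, not part of this commit.

import time
import gradio as gr


def echo_stream(user_input):
    # Hypothetical stand-in for medusa_chat_interface: streams the input back
    # one character at a time so partial results appear in the UI.
    partial = ""
    for ch in "Echo: " + user_input:
        partial += ch
        yield partial  # each yielded value replaces the textbox contents
        time.sleep(0.02)


if __name__ == "__main__":
    demo = gr.Interface(
        echo_stream,
        gr.components.Textbox(placeholder="Say something..."),
        gr.components.Textbox(label="Streamed response")
    )
    # Enable the queue so generator (streaming) outputs work, matching the script above.
    demo.queue().launch()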

0 commit comments
