
Commit 8891f35

Merge pull request #23 from Mrw33554432/main
Add a simple gradio interface, make life easier
2 parents: c4e0c63 + 0c55d52

File tree

1 file changed: 111 additions, 0 deletions


simple_gradio_interface.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
import gradio as gr
import time
import torch
from medusa.model.medusa_model import MedusaModel
from fastchat.model.model_adapter import get_conversation_template


# Global variables shared across the Gradio callbacks
chat_history = ""
model = None
tokenizer = None
conv = None


def load_model_function(model_name, load_in_8bit=False, load_in_4bit=False):
    model_name = model_name or "FasterDecoding/medusa-vicuna-7b-v1.3"
    global model, tokenizer, conv

    try:
        model = MedusaModel.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",
            load_in_8bit=load_in_8bit,
            load_in_4bit=load_in_4bit
        )
        tokenizer = model.get_tokenizer()
        conv = get_conversation_template("vicuna")
        return "Model loaded successfully!"
    except Exception:
        return "Error loading the model. Please check the model name and try again."


def reset_conversation():
    """
    Reset the global conversation and chat history.
    """
    global conv, chat_history
    conv = get_conversation_template("vicuna")
    chat_history = ""


def medusa_chat_interface(user_input, temperature, max_steps, no_history):
    global model, tokenizer, conv, chat_history

    # Reset the conversation if no_history is checked
    if no_history:
        reset_conversation()

    if not model or not tokenizer:
        # This function is a generator, so a plain return would never reach the
        # UI; yield the error message instead and stop.
        yield "Error: Model not loaded!", chat_history
        return

    chat_history += "\nYou: " + user_input
    conv.append_message(conv.roles[0], user_input)
    conv.append_message(conv.roles[1], '')
    prompt = conv.get_prompt()

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.base_model.device)

    # Stream partial outputs to the UI as Medusa generates them
    outputs = model.medusa_generate(input_ids, temperature=temperature, max_steps=max_steps)
    response = ""
    for output in outputs:
        response = output['text']
        yield response, chat_history
        time.sleep(0.01)

    chat_history += "\nMedusa: " + response.strip()

    # Yield once more so the updated chat history appears in the UI
    yield response, chat_history


if __name__ == "__main__":
    load_model_interface = gr.Interface(
        load_model_function,
        [
            gr.components.Textbox(placeholder="FasterDecoding/medusa-vicuna-7b-v1.3", label="Model Name"),
            gr.components.Checkbox(label="Use 8-bit Quantization"),
            gr.components.Checkbox(label="Use 4-bit Quantization"),
        ],
        gr.components.Textbox(label="Model Load Status", type="text"),
        description="Load Medusa Model",
        title="Medusa Model Loader",
        live=False,
        api_name="load_model"
    )

    # Chat Interface
    chat_interface = gr.Interface(
        medusa_chat_interface,
        [
            gr.components.Textbox(placeholder="Ask Medusa...", label="User Input"),
            gr.components.Slider(minimum=0, maximum=1.5, label="Temperature"),
            gr.components.Slider(minimum=50, maximum=1000, label="Max Steps"),
            gr.components.Checkbox(label="No History"),
        ],
        [
            gr.components.Textbox(label="Medusa's Response", type="text"),
            gr.components.Textbox(label="Chat History", type="text")
        ],
        live=False,
        description="Chat with Medusa",
        title="Medusa Chatbox",
        api_name="chat"
    )

    # Combine the interfaces in a TabbedInterface
    combined_interface = gr.TabbedInterface([load_model_interface, chat_interface],
                                            ["Load Model", "Chat"])

    # Launch the combined interface
    combined_interface.queue().launch()
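
A minimal usage sketch, assuming the file above is importable as simple_gradio_interface and that the Medusa/fastchat dependencies plus a suitable GPU are available: it drives the two callbacks directly, without launching the Gradio UI, for a quick smoke test. The model name, temperature, and max_steps values are illustrative, not prescribed by the commit.

# Hypothetical smoke test; all argument values below are assumptions.
from simple_gradio_interface import load_model_function, medusa_chat_interface

print(load_model_function("FasterDecoding/medusa-vicuna-7b-v1.3", load_in_8bit=True))

# medusa_chat_interface is a generator that streams (response, chat_history)
# pairs; keep only the last pair it yields.
response, history = "", ""
for response, history in medusa_chat_interface(
    "What is Medusa decoding?", temperature=0.7, max_steps=512, no_history=True
):
    pass
print(response)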
