
Commit 0c55d52

add a model loader page and some settings.
1 parent b27b6fd commit 0c55d52

File tree

1 file changed: +76 -24 lines changed


simple_gradio_interface.py

Lines changed: 76 additions & 24 deletions
@@ -4,56 +4,108 @@
 from medusa.model.medusa_model import MedusaModel
 from fastchat.model.model_adapter import get_conversation_template
 
-# Global variable to store the chat history
+# Global variables
 chat_history = ""
+model = None
+tokenizer = None
+conv = None
 
 
-def medusa_chat_interface(user_input):
+def load_model_function(model_name, load_in_8bit=False, load_in_4bit=False):
+    model_name = model_name or "FasterDecoding/medusa-vicuna-7b-v1.3"
+    global model, tokenizer, conv
+
+    try:
+        model = MedusaModel.from_pretrained(
+            model_name,
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True,
+            device_map="auto",
+            load_in_8bit=load_in_8bit,
+            load_in_4bit=load_in_4bit
+        )
+        tokenizer = model.get_tokenizer()
+        conv = get_conversation_template("vicuna")
+        return "Model loaded successfully!"
+    except:
+        return "Error loading the model. Please check the model name and try again."
+
+
+def reset_conversation():
+    """
+    Reset the global conversation and chat history
+    """
+    global conv, chat_history
+    conv = get_conversation_template("vicuna")
+    chat_history = ""
+
+
+def medusa_chat_interface(user_input, temperature, max_steps, no_history):
     global model, tokenizer, conv, chat_history
 
-    # Add user's input to chat history
-    chat_history += "\nYou: " + user_input
+    # Reset the conversation if no_history is checked
+    if no_history:
+        reset_conversation()
 
-    # Process the user input and get the model's response
+    if not model or not tokenizer:
+        return "Error: Model not loaded!", chat_history
+
+    chat_history += "\nYou: " + user_input
     conv.append_message(conv.roles[0], user_input)
-    conv.append_message(conv.roles[1], '')  # Placeholder for the Medusa response
+    conv.append_message(conv.roles[1], '')
     prompt = conv.get_prompt()
-    print(prompt)
 
     input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.base_model.device)
 
-    outputs = model.medusa_generate(input_ids, temperature=0.7, max_steps=512)
+    outputs = model.medusa_generate(input_ids, temperature=temperature, max_steps=max_steps)
     response = ""
     for output in outputs:
        response = output['text']
-        # Send the current response to the output box
        yield response, chat_history
        time.sleep(0.01)
 
-    # Update chat history with the complete Medusa's response after the loop
    chat_history += "\nMedusa: " + response.strip()
 
    return response, chat_history
 
 
 if __name__ == "__main__":
-    MODEL_PATH = "FasterDecoding/medusa-vicuna-7b-v1.3"
-    model = MedusaModel.from_pretrained(
-        MODEL_PATH,
-        torch_dtype=torch.float16,
-        low_cpu_mem_usage=True,
-        device_map="auto"
+    load_model_interface = gr.Interface(
+        load_model_function,
+        [
+            gr.components.Textbox(placeholder="FasterDecoding/medusa-vicuna-7b-v1.3", label="Model Name"),
+            gr.components.Checkbox(label="Use 8-bit Quantization"),
+            gr.components.Checkbox(label="Use 4-bit Quantization"),
+        ],
+        gr.components.Textbox(label="Model Load Status", type="text"),
+        description="Load Medusa Model",
+        title="Medusa Model Loader",
+        live=False,
+        api_name="load_model"
     )
-    tokenizer = model.get_tokenizer()
-    conv = get_conversation_template("vicuna")
 
-    interface = gr.Interface(
+    # Chat Interface
+    chat_interface = gr.Interface(
        medusa_chat_interface,
-        gr.components.Textbox(placeholder="Ask Medusa..."),
-        [gr.components.Textbox(label="Medusa's Response", type="text"),
-         gr.components.Textbox(label="Chat History", type="text")],
+        [
+            gr.components.Textbox(placeholder="Ask Medusa...", label="User Input"),
+            gr.components.Slider(minimum=0, maximum=1.5, label="Temperature"),
+            gr.components.Slider(minimum=50, maximum=1000, label="Max Steps"),
+            gr.components.Checkbox(label="No History"),
+        ],
+        [
+            gr.components.Textbox(label="Medusa's Response", type="text"),
+            gr.components.Textbox(label="Chat History", type="text")
+        ],
        live=False,
        description="Chat with Medusa",
-        title="Medusa Chatbox"
+        title="Medusa Chatbox",
+        api_name="chat"
    )
-    interface.queue().launch()
+
+    # Combine the interfaces in a TabbedInterface
+    combined_interface = gr.TabbedInterface([load_model_interface, chat_interface],
+                                            ["Load Model", "Chat"])
+
+    # Launch the combined interface
+    combined_interface.queue().launch()
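
A minimal sketch (not part of this commit) of driving the two new named endpoints from Python with gradio_client. It assumes the app is running at Gradio's default local address and that the api_name values above surface as /load_model and /chat; the prompt and sampling values are illustrative only.

from gradio_client import Client

client = Client("http://127.0.0.1:7860/")

# "Load Model" tab: returns the status string from load_model_function.
status = client.predict(
    "FasterDecoding/medusa-vicuna-7b-v1.3",  # Model Name
    False,                                   # Use 8-bit Quantization
    False,                                   # Use 4-bit Quantization
    api_name="/load_model",
)
print(status)

# "Chat" tab: medusa_chat_interface is a generator, so predict() waits
# for the stream to finish and returns the final (response, history) pair.
response, history = client.predict(
    "What is speculative decoding?",  # User Input (hypothetical prompt)
    0.7,                              # Temperature
    512,                              # Max Steps
    False,                            # No History
    api_name="/chat",
)
print(response)
print(history)

To consume intermediate tokens instead of only the final pair, client.submit() with the same arguments returns a job whose partial outputs can be iterated.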

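One edge case the new checkboxes permit: ticking both boxes passes load_in_8bit=True and load_in_4bit=True through to from_pretrained, which transformers rejects, and the bare except in load_model_function then reports it as the generic load error. (Either flag also requires the bitsandbytes package to be installed.) A hypothetical pre-check, not part of this commit, could surface a clearer message:

def check_quantization_flags(load_in_8bit, load_in_4bit):
    # Hypothetical helper: the two quantization modes are mutually
    # exclusive in transformers' from_pretrained, so fail fast with a
    # message naming the actual problem.
    if load_in_8bit and load_in_4bit:
        return "Error: choose either 8-bit or 4-bit quantization, not both."
    return None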