Skip to content

Commit 7ccb851

Browse files
authored
Merge branch 'lm-sys:main' into fix-llama3.1_template
2 parents 97eca5d + d161b64 commit 7ccb851

File tree

5 files changed

+184
-13
lines changed

5 files changed

+184
-13
lines changed

fastchat/model/model_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2483,7 +2483,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
24832483

24842484
class NoSystemAdapter(BaseModelAdapter):
24852485
def match(self, model_path: str):
2486-
keyword_list = ["athene-70b"]
2486+
keyword_list = ["athene-70b", "p2l"]
24872487

24882488
for keyword in keyword_list:
24892489
if keyword == model_path.lower():

fastchat/serve/api_provider.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def get_api_provider_stream_iter(
2323
top_p,
2424
max_new_tokens,
2525
state,
26+
extra_body=None,
2627
):
2728
if model_api_dict["api_type"] == "openai":
2829
if model_api_dict.get("vision-arena", False):
@@ -246,6 +247,18 @@ def get_api_provider_stream_iter(
246247
api_key=model_api_dict["api_key"],
247248
conversation_id=state.conv_id,
248249
)
250+
elif model_api_dict["api_type"] == "p2l":
251+
prompt = conv.to_openai_api_messages()
252+
stream_iter = p2l_api_stream_iter(
253+
model_api_dict["model_name"],
254+
prompt,
255+
temperature,
256+
top_p,
257+
max_new_tokens,
258+
api_base=model_api_dict["api_base"],
259+
api_key=model_api_dict["api_key"],
260+
extra_body=extra_body,
261+
)
249262
else:
250263
raise NotImplementedError()
251264

@@ -412,6 +425,74 @@ def column_api_stream_iter(
412425
}
413426

414427

428+
def p2l_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_base=None,
    api_key=None,
    extra_body=None,
):
    """Stream chat completions from a P2L (prompt-to-leaderboard) router endpoint.

    Yields dicts of the form ``{"text": <accumulated text>, "error_code": 0}``.
    On the first chunk, the dict may additionally carry:
      - "ans_model": the model the router actually answered with
        (read from ``chunk.choices[0].delta.model`` when present), and
      - "router_outputs": the router's per-model outputs
        (read from ``chunk.router_outputs`` when present).

    NOTE(review): ``temperature`` and ``top_p`` are accepted but NOT forwarded
    to the API call, and are deliberately logged as None below — presumably the
    router controls sampling itself. Confirm before changing this.
    """
    import openai

    client = openai.OpenAI(
        base_url=api_base,
        api_key=api_key or "-",
        timeout=180,
    )

    # Build a text-only copy of the messages for logging: vision messages
    # carry non-text content (e.g. image payloads) we do not want in logs.
    text_messages = []
    for message in messages:
        if isinstance(message["content"], str):  # text-only model
            text_messages.append(message)
        else:  # vision model: keep only the "text" content parts
            filtered_content_list = [
                content for content in message["content"] if content["type"] == "text"
            ]
            text_messages.append(
                {"role": message["role"], "content": filtered_content_list}
            )

    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": None,
        "top_p": None,
        "max_new_tokens": max_new_tokens,
        "extra_body": extra_body,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=max_new_tokens,
        stream=True,
        extra_body=extra_body,
    )
    text = ""
    for chunk_idx, chunk in enumerate(res):
        # Some providers emit keep-alive chunks with no choices; skip those.
        if len(chunk.choices) > 0:
            text += chunk.choices[0].delta.content or ""

            data = {
                "text": text,
                "error_code": 0,
            }

            if chunk_idx == 0:
                # Router metadata only arrives on the first chunk.
                if hasattr(chunk.choices[0].delta, "model"):
                    data["ans_model"] = chunk.choices[0].delta.model

                if hasattr(chunk, "router_outputs"):
                    data["router_outputs"] = chunk.router_outputs

            yield data
415496
def upload_openai_file_to_gcs(file_id):
416497
import openai
417498
from google.cloud import storage

fastchat/serve/gradio_web_server.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import random
1212
import time
1313
import uuid
14-
from typing import List
14+
from typing import List, Dict
1515

1616
import gradio as gr
1717
import requests
@@ -119,6 +119,8 @@ def __init__(self, model_name, is_vision=False):
119119
self.model_name = model_name
120120
self.oai_thread_id = None
121121
self.is_vision = is_vision
122+
self.ans_models = []
123+
self.router_outputs = []
122124

123125
# NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes.
124126
self.has_csam_image = False
@@ -128,6 +130,12 @@ def __init__(self, model_name, is_vision=False):
128130
self.regen_support = False
129131
self.init_system_prompt(self.conv, is_vision)
130132

133+
def update_ans_models(self, ans: str) -> None:
    """Record the model name the P2L router answered with for this turn."""
    self.ans_models += [ans]
135+
136+
def update_router_outputs(self, outputs: Dict[str, float]) -> None:
    """Record the router's per-model output scores for this turn."""
    self.router_outputs += [outputs]
138+
131139
def init_system_prompt(self, conv, is_vision):
132140
system_prompt = conv.get_system_message(is_vision)
133141
if len(system_prompt) == 0:
@@ -154,6 +162,20 @@ def dict(self):
154162
}
155163
)
156164

165+
if self.ans_models:
166+
base.update(
167+
{
168+
"ans_models": self.ans_models,
169+
}
170+
)
171+
172+
if self.router_outputs:
173+
base.update(
174+
{
175+
"router_outputs": self.router_outputs,
176+
}
177+
)
178+
157179
if self.is_vision:
158180
base.update({"has_csam_image": self.has_csam_image})
159181
return base
@@ -420,7 +442,7 @@ def is_limit_reached(model_name, ip):
420442

421443

422444
def bot_response(
423-
state,
445+
state: State,
424446
temperature,
425447
top_p,
426448
max_new_tokens,
@@ -504,6 +526,8 @@ def bot_response(
504526
if not custom_system_prompt:
505527
conv.set_system_message("")
506528

529+
extra_body = None
530+
507531
if use_recommended_config:
508532
recommended_config = model_api_dict.get("recommended_config", None)
509533
if recommended_config is not None:
@@ -512,6 +536,7 @@ def bot_response(
512536
max_new_tokens = recommended_config.get(
513537
"max_new_tokens", max_new_tokens
514538
)
539+
extra_body = recommended_config.get("extra_body", None)
515540

516541
stream_iter = get_api_provider_stream_iter(
517542
conv,
@@ -521,6 +546,7 @@ def bot_response(
521546
top_p,
522547
max_new_tokens,
523548
state,
549+
extra_body=extra_body,
524550
)
525551

526552
html_code = ' <span class="cursor"></span> '
@@ -532,6 +558,18 @@ def bot_response(
532558
try:
533559
data = {"text": ""}
534560
for i, data in enumerate(stream_iter):
561+
# Change for P2L:
562+
if i == 0:
563+
if "ans_model" in data:
564+
ans_model = data.get("ans_model")
565+
566+
state.update_ans_models(ans_model)
567+
568+
if "router_outputs" in data:
569+
router_outputs = data.get("router_outputs")
570+
571+
state.update_router_outputs(router_outputs)
572+
535573
if data["error_code"] == 0:
536574
output = data["text"].strip()
537575
conv.update_last_message(output + "▌")
@@ -688,6 +726,22 @@ def bot_response(
688726
.block {
689727
overflow-y: hidden !important;
690728
}
729+
730+
.visualizer {
731+
overflow: hidden;
732+
height: 60vw;
733+
border: 1px solid lightgrey;
734+
border-radius: 10px;
735+
}
736+
737+
@media screen and (max-width: 769px) {
738+
.visualizer {
739+
height: 180vw;
740+
overflow-y: scroll;
741+
width: 100%;
742+
overflow-x: hidden;
743+
}
744+
}
691745
"""
692746

693747

fastchat/serve/gradio_web_server_multi.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44
"""
55

66
import argparse
7-
import pickle
8-
import time
9-
from typing import List
10-
117
import gradio as gr
128

139
from fastchat.serve.gradio_block_arena_anony import (
@@ -54,6 +50,36 @@
5450
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
5551

5652

53+
def build_visualizer():
    """Render the "Arena Visualizer" tab contents.

    Emits, in order: an intro markdown header, a collapsed accordion with
    usage instructions, and an iframe embedding the externally hosted
    topic-cluster chart. Must be called inside an active gradio Blocks
    context (gr.Markdown/gr.Accordion/gr.HTML register into it).
    """
    visualizer_markdown = """
    # 🔍 Arena Visualizer
    This tool provides an interactive way to explore how people are using Chatbot Arena.
    Using *[topic clustering](https://github.com/MaartenGr/BERTopic)*, we organized user-submitted prompts from Arena battles into broad and specific categories.
    Dive in to uncover insights about the distribution and themes of these prompts!
    """
    gr.Markdown(visualizer_markdown, elem_id="visualizer_markdown")
    expandText = "👇 Expand to see detailed instructions on how to use the visualizer"
    # Instructions start collapsed so the chart stays above the fold.
    with gr.Accordion(expandText, open=False):
        instructions = """
        - Hover Over Segments: View the category name, the number of prompts, and their percentage.
            - *On mobile devices*: Tap instead of hover.
        - Click to Explore:
            - Click on a main category to see its subcategories.
            - Click on subcategories to see example prompts in the sidebar.
        - Undo and Reset: Click the center of the chart to return to the top level.

        Visualizer is created using Arena battle data collected from 2024/6 to 2024/8.
        """
        gr.Markdown(instructions)

    # The chart itself is a static page served from GCS; the .visualizer CSS
    # class (defined in gradio_web_server's block_css) controls its sizing.
    frame = """
    <iframe class="visualizer" width="100%"
        src="https://storage.googleapis.com/public-arena-no-cors/index.html">
    </iframe>
    """
    gr.HTML(frame)
81+
82+
5783
def load_demo(context: Context, request: gr.Request):
5884
ip = get_ip(request)
5985
logger.info(f"load_demo. ip: {ip}. params: {request.query_params}")
@@ -199,12 +225,14 @@ def build_demo(
199225
arena_hard_table,
200226
show_plot=True,
201227
)
228+
if args.show_visualizer:
229+
with gr.Tab("🔍 Arena Visualizer", id=5):
230+
build_visualizer()
202231

203232
with gr.Tab("ℹ️ About Us", id=4):
204-
about = build_about()
233+
build_about()
205234

206235
context_state = gr.State(context)
207-
url_params = gr.JSON(visible=False)
208236

209237
if args.model_list_mode not in ["once", "reload"]:
210238
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
@@ -271,7 +299,8 @@ def build_demo(
271299
parser.add_argument(
272300
"--gradio-auth-path",
273301
type=str,
274-
help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
302+
help='Set the gradio authentication file path. The file should contain one or \
303+
more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
275304
default=None,
276305
)
277306
parser.add_argument(
@@ -286,7 +315,8 @@ def build_demo(
286315
parser.add_argument(
287316
"--gradio-root-path",
288317
type=str,
289-
help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix",
318+
help="Sets the gradio root path, eg /abc/def. Useful when running behind a \
319+
reverse-proxy or at a custom URL path prefix",
290320
)
291321
parser.add_argument(
292322
"--ga-id",
@@ -305,6 +335,12 @@ def build_demo(
305335
type=str,
306336
help="Set the password for the gradio web server",
307337
)
338+
parser.add_argument(
339+
"--show-visualizer",
340+
action="store_true",
341+
default=False,
342+
help="Show the Data Visualizer tab",
343+
)
308344
args = parser.parse_args()
309345
logger.info(f"args: {args}")
310346

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ dependencies = [
1919
]
2020

2121
[project.optional-dependencies]
22-
model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"]
23-
webui = ["gradio>=4.10"]
22+
model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf", "openai", "anthropic"]
23+
webui = ["gradio>=4.10", "plotly", "scipy"]
2424
train = ["einops", "flash-attn>=2.0", "wandb"]
2525
llm_judge = ["openai<1", "anthropic>=0.3", "ray"]
2626
dev = ["black==23.3.0", "pylint==2.8.2"]

0 commit comments

Comments
 (0)