Skip to content

Commit f42f69b

Browse files
author
Roberto Montoya
committed
Merge remote-tracking branch 'origin/main' into feature/jab-api-0.3
# Conflicts: # fastchat/serve/api_provider.py
2 parents e27dd24 + d161b64 commit f42f69b

File tree

5 files changed

+184
-13
lines changed

5 files changed

+184
-13
lines changed

fastchat/model/model_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2489,7 +2489,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
24892489

24902490
class NoSystemAdapter(BaseModelAdapter):
24912491
def match(self, model_path: str):
2492-
keyword_list = ["athene-70b"]
2492+
keyword_list = ["athene-70b", "p2l"]
24932493

24942494
for keyword in keyword_list:
24952495
if keyword == model_path.lower():

fastchat/serve/api_provider.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def get_api_provider_stream_iter(
2323
top_p,
2424
max_new_tokens,
2525
state,
26+
extra_body=None,
2627
):
2728
if model_api_dict["api_type"] == "openai":
2829
if model_api_dict.get("vision-arena", False):
@@ -246,6 +247,18 @@ def get_api_provider_stream_iter(
246247
api_key=model_api_dict["api_key"],
247248
conversation_id=state.conv_id,
248249
)
250+
elif model_api_dict["api_type"] == "p2l":
251+
prompt = conv.to_openai_api_messages()
252+
stream_iter = p2l_api_stream_iter(
253+
model_api_dict["model_name"],
254+
prompt,
255+
temperature,
256+
top_p,
257+
max_new_tokens,
258+
api_base=model_api_dict["api_base"],
259+
api_key=model_api_dict["api_key"],
260+
extra_body=extra_body,
261+
)
249262
elif model_api_dict["api_type"] == "jab":
250263
messages = conv.to_jab_api_messages()
251264
stream_iter = jab_api_stream_iter(
@@ -421,6 +434,74 @@ def column_api_stream_iter(
421434
}
422435

423436

437+
def p2l_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_base=None,
    api_key=None,
    extra_body=None,
):
    """Stream a chat completion from a P2L (prompt-to-leaderboard) router endpoint.

    Yields dicts of the form ``{"text": <accumulated text>, "error_code": 0}``.
    On the first chunk, the dict may additionally carry ``ans_model`` (the model
    the router selected) and ``router_outputs`` (router scores), when the server
    provides them.

    NOTE(review): ``temperature`` and ``top_p`` are accepted for signature parity
    with the other ``*_api_stream_iter`` helpers but are deliberately not
    forwarded to the API call (they are logged as ``None``) — presumably the
    router controls sampling itself; confirm against the P2L server contract.
    """
    import openai

    client = openai.OpenAI(
        base_url=api_base,
        api_key=api_key or "-",
        timeout=180,
    )

    # Build a text-only copy of the conversation for logging: vision messages
    # carry a list of content parts, so keep only the "text" parts.
    text_messages = []
    for message in messages:
        # Fixed: use isinstance() instead of `type(...) == str` so str
        # subclasses are handled and the check is idiomatic.
        if isinstance(message["content"], str):  # text-only model
            text_messages.append(message)
        else:  # vision model
            filtered_content_list = [
                content for content in message["content"] if content["type"] == "text"
            ]
            text_messages.append(
                {"role": message["role"], "content": filtered_content_list}
            )

    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": None,
        "top_p": None,
        "max_new_tokens": max_new_tokens,
        "extra_body": extra_body,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=max_new_tokens,
        stream=True,
        extra_body=extra_body,
    )
    text = ""
    for chunk_idx, chunk in enumerate(res):
        if len(chunk.choices) > 0:
            text += chunk.choices[0].delta.content or ""

        data = {
            "text": text,
            "error_code": 0,
        }

        # The router reports its decision only on the first chunk.
        # Fixed: guard chunk.choices before indexing — the accumulation path
        # above already allows for choice-less chunks, but the original first-
        # chunk branch dereferenced chunk.choices[0] unconditionally, which
        # would raise IndexError on a metadata-only first chunk.
        if chunk_idx == 0:
            if len(chunk.choices) > 0 and hasattr(chunk.choices[0].delta, "model"):
                data["ans_model"] = chunk.choices[0].delta.model

            if hasattr(chunk, "router_outputs"):
                data["router_outputs"] = chunk.router_outputs

        yield data
503+
504+
424505
def upload_openai_file_to_gcs(file_id):
425506
import openai
426507
from google.cloud import storage

fastchat/serve/gradio_web_server.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import random
1212
import time
1313
import uuid
14-
from typing import List
14+
from typing import List, Dict
1515

1616
import gradio as gr
1717
import requests
@@ -119,6 +119,8 @@ def __init__(self, model_name, is_vision=False):
119119
self.model_name = model_name
120120
self.oai_thread_id = None
121121
self.is_vision = is_vision
122+
self.ans_models = []
123+
self.router_outputs = []
122124

123125
# NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes.
124126
self.has_csam_image = False
@@ -128,6 +130,12 @@ def __init__(self, model_name, is_vision=False):
128130
self.regen_support = False
129131
self.init_system_prompt(self.conv, is_vision)
130132

133+
def update_ans_models(self, ans: str) -> None:
134+
self.ans_models.append(ans)
135+
136+
def update_router_outputs(self, outputs: Dict[str, float]) -> None:
137+
self.router_outputs.append(outputs)
138+
131139
def init_system_prompt(self, conv, is_vision):
132140
system_prompt = conv.get_system_message(is_vision)
133141
if len(system_prompt) == 0:
@@ -154,6 +162,20 @@ def dict(self):
154162
}
155163
)
156164

165+
if self.ans_models:
166+
base.update(
167+
{
168+
"ans_models": self.ans_models,
169+
}
170+
)
171+
172+
if self.router_outputs:
173+
base.update(
174+
{
175+
"router_outputs": self.router_outputs,
176+
}
177+
)
178+
157179
if self.is_vision:
158180
base.update({"has_csam_image": self.has_csam_image})
159181
return base
@@ -420,7 +442,7 @@ def is_limit_reached(model_name, ip):
420442

421443

422444
def bot_response(
423-
state,
445+
state: State,
424446
temperature,
425447
top_p,
426448
max_new_tokens,
@@ -504,6 +526,8 @@ def bot_response(
504526
if not custom_system_prompt:
505527
conv.set_system_message("")
506528

529+
extra_body = None
530+
507531
if use_recommended_config:
508532
recommended_config = model_api_dict.get("recommended_config", None)
509533
if recommended_config is not None:
@@ -512,6 +536,7 @@ def bot_response(
512536
max_new_tokens = recommended_config.get(
513537
"max_new_tokens", max_new_tokens
514538
)
539+
extra_body = recommended_config.get("extra_body", None)
515540

516541
stream_iter = get_api_provider_stream_iter(
517542
conv,
@@ -521,6 +546,7 @@ def bot_response(
521546
top_p,
522547
max_new_tokens,
523548
state,
549+
extra_body=extra_body,
524550
)
525551

526552
html_code = ' <span class="cursor"></span> '
@@ -532,6 +558,18 @@ def bot_response(
532558
try:
533559
data = {"text": ""}
534560
for i, data in enumerate(stream_iter):
561+
# Change for P2L:
562+
if i == 0:
563+
if "ans_model" in data:
564+
ans_model = data.get("ans_model")
565+
566+
state.update_ans_models(ans_model)
567+
568+
if "router_outputs" in data:
569+
router_outputs = data.get("router_outputs")
570+
571+
state.update_router_outputs(router_outputs)
572+
535573
if data["error_code"] == 0:
536574
output = data["text"].strip()
537575
conv.update_last_message(output + "▌")
@@ -688,6 +726,22 @@ def bot_response(
688726
.block {
689727
overflow-y: hidden !important;
690728
}
729+
730+
.visualizer {
731+
overflow: hidden;
732+
height: 60vw;
733+
border: 1px solid lightgrey;
734+
border-radius: 10px;
735+
}
736+
737+
@media screen and (max-width: 769px) {
738+
.visualizer {
739+
height: 180vw;
740+
overflow-y: scroll;
741+
width: 100%;
742+
overflow-x: hidden;
743+
}
744+
}
691745
"""
692746

693747

fastchat/serve/gradio_web_server_multi.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44
"""
55

66
import argparse
7-
import pickle
8-
import time
9-
from typing import List
10-
117
import gradio as gr
128

139
from fastchat.serve.gradio_block_arena_anony import (
@@ -54,6 +50,36 @@
5450
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
5551

5652

53+
def build_visualizer():
    """Build the contents of the Arena Visualizer tab.

    Lays out, in order: an introductory markdown header, a collapsed
    accordion with usage instructions, and an iframe embedding the
    externally hosted topic-cluster visualization.
    """
    intro_markdown = """
# 🔍 Arena Visualizer
This tool provides an interactive way to explore how people are using Chatbot Arena.
Using *[topic clustering](https://github.com/MaartenGr/BERTopic)*, we organized user-submitted prompts from Arena battles into broad and specific categories.
Dive in to uncover insights about the distribution and themes of these prompts!
"""
    gr.Markdown(intro_markdown, elem_id="visualizer_markdown")

    accordion_label = "👇 Expand to see detailed instructions on how to use the visualizer"
    with gr.Accordion(accordion_label, open=False):
        usage_instructions = """
- Hover Over Segments: View the category name, the number of prompts, and their percentage.
- *On mobile devices*: Tap instead of hover.
- Click to Explore:
- Click on a main category to see its subcategories.
- Click on subcategories to see example prompts in the sidebar.
- Undo and Reset: Click the center of the chart to return to the top level.

Visualizer is created using Arena battle data collected from 2024/6 to 2024/8.
"""
        gr.Markdown(usage_instructions)

    embed_html = """
<iframe class="visualizer" width="100%"
src="https://storage.googleapis.com/public-arena-no-cors/index.html">
</iframe>
"""
    gr.HTML(embed_html)
81+
82+
5783
def load_demo(context: Context, request: gr.Request):
5884
ip = get_ip(request)
5985
logger.info(f"load_demo. ip: {ip}. params: {request.query_params}")
@@ -199,12 +225,14 @@ def build_demo(
199225
arena_hard_table,
200226
show_plot=True,
201227
)
228+
if args.show_visualizer:
229+
with gr.Tab("🔍 Arena Visualizer", id=5):
230+
build_visualizer()
202231

203232
with gr.Tab("ℹ️ About Us", id=4):
204-
about = build_about()
233+
build_about()
205234

206235
context_state = gr.State(context)
207-
url_params = gr.JSON(visible=False)
208236

209237
if args.model_list_mode not in ["once", "reload"]:
210238
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
@@ -271,7 +299,8 @@ def build_demo(
271299
parser.add_argument(
272300
"--gradio-auth-path",
273301
type=str,
274-
help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
302+
help='Set the gradio authentication file path. The file should contain one or \
303+
more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
275304
default=None,
276305
)
277306
parser.add_argument(
@@ -286,7 +315,8 @@ def build_demo(
286315
parser.add_argument(
287316
"--gradio-root-path",
288317
type=str,
289-
help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix",
318+
help="Sets the gradio root path, eg /abc/def. Useful when running behind a \
319+
reverse-proxy or at a custom URL path prefix",
290320
)
291321
parser.add_argument(
292322
"--ga-id",
@@ -305,6 +335,12 @@ def build_demo(
305335
type=str,
306336
help="Set the password for the gradio web server",
307337
)
338+
parser.add_argument(
339+
"--show-visualizer",
340+
action="store_true",
341+
default=False,
342+
help="Show the Data Visualizer tab",
343+
)
308344
args = parser.parse_args()
309345
logger.info(f"args: {args}")
310346

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ dependencies = [
1919
]
2020

2121
[project.optional-dependencies]
22-
model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"]
23-
webui = ["gradio>=4.10"]
22+
model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf", "openai", "anthropic"]
23+
webui = ["gradio>=4.10", "plotly", "scipy"]
2424
train = ["einops", "flash-attn>=2.0", "wandb"]
2525
llm_judge = ["openai<1", "anthropic>=0.3", "ray"]
2626
dev = ["black==23.3.0", "pylint==2.8.2"]

0 commit comments

Comments
 (0)