Skip to content

Commit 1e4c97a

Browse files
committed
p2l stuff
1 parent 8664268 commit 1e4c97a

File tree

5 files changed

+143
-6
lines changed

5 files changed

+143
-6
lines changed

fastchat/model/model_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2489,7 +2489,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
24892489

24902490
class NoSystemAdapter(BaseModelAdapter):
24912491
def match(self, model_path: str):
2492-
keyword_list = ["athene-70b"]
2492+
keyword_list = ["athene-70b", "p2l"]
24932493

24942494
for keyword in keyword_list:
24952495
if keyword == model_path.lower():

fastchat/serve/api_provider.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,17 @@ def get_api_provider_stream_iter(
246246
api_key=model_api_dict["api_key"],
247247
conversation_id=state.conv_id,
248248
)
249+
elif model_api_dict["api_type"] == "p2l":
250+
prompt = conv.to_openai_api_messages()
251+
stream_iter = p2l_api_stream_iter(
252+
model_api_dict["model_name"],
253+
prompt,
254+
temperature,
255+
top_p,
256+
max_new_tokens,
257+
api_base=model_api_dict["api_base"],
258+
api_key=model_api_dict["api_key"],
259+
)
249260
else:
250261
raise NotImplementedError()
251262

@@ -412,6 +423,72 @@ def column_api_stream_iter(
412423
}
413424

414425

426+
def p2l_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_base=None,
    api_key=None,
):
    """Stream a chat completion from a P2L (prompt-to-leaderboard) router endpoint.

    Yields dicts containing the accumulated "text" so far and "error_code" == 0.
    On the first chunk, the dict may additionally carry:
      - "ans_model": the model the router actually answered with
        (read from ``chunk.choices[0].delta.model`` when present), and
      - "router_outputs": the router's per-model outputs
        (read from ``chunk.router_outputs`` when present).

    NOTE(review): ``temperature`` and ``top_p`` are accepted for signature parity
    with the other ``*_api_stream_iter`` helpers but are never forwarded to the
    API — presumably the P2L endpoint picks its own decoding settings. They are
    logged as ``None`` below; confirm this is intentional.
    """
    import openai

    client = openai.OpenAI(
        base_url=api_base,
        api_key=api_key or "-",  # OpenAI client requires a non-empty key
        timeout=180,
    )

    # Build a text-only copy of the conversation purely for request logging:
    # vision messages are reduced to their "text" content parts.
    text_messages = []
    for message in messages:
        if isinstance(message["content"], str):  # text-only model
            text_messages.append(message)
        else:  # vision model: content is a list of typed parts
            filtered_content_list = [
                content for content in message["content"] if content["type"] == "text"
            ]
            text_messages.append(
                {"role": message["role"], "content": filtered_content_list}
            )

    gen_params = {
        "model": model_name,
        "prompt": text_messages,
        "temperature": None,
        "top_p": None,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    res = client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=max_new_tokens,
        stream=True,
    )
    text = ""
    for chunk_idx, chunk in enumerate(res):
        if len(chunk.choices) > 0:
            text += chunk.choices[0].delta.content or ""

            data = {
                "text": text,
                "error_code": 0,
            }

            # The router annotates only the first streamed chunk with its
            # routing metadata; surface it to the caller once.
            if chunk_idx == 0:
                if hasattr(chunk.choices[0].delta, "model"):
                    data["ans_model"] = chunk.choices[0].delta.model

                if hasattr(chunk, "router_outputs"):
                    data["router_outputs"] = chunk.router_outputs

            yield data
490+
491+
415492
def upload_openai_file_to_gcs(file_id):
416493
import openai
417494
from google.cloud import storage

fastchat/serve/gradio_block_arena_anony.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,25 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
175175
)
176176

177177

178-
SAMPLING_WEIGHTS = {}
178+
# Models entered into anonymous battles with uniform weight; "p2l" is given a
# far larger weight so it is sampled in nearly every battle.
_UNIFORM_WEIGHT_MODELS = (
    "claude-3-5-haiku-20241022",
    "claude-3-5-sonnet-20240620",
    "claude-3-5-sonnet-20241022",
    "gpt-4-1106-preview",
    "gpt-4-turbo-2024-04-09",
    "gpt-4o-2024-05-13",
    "gpt-4o-2024-08-06",
    "gpt-4o-mini-2024-07-18",
    "o1-mini",
    "yi-lightning",
    "llama-3.1-405b-instruct-fp8",
    "llama-3.1-8b-instruct",
    "llama-3.1-70b-instruct",
    "gemini-1.5-pro-001",
    "mistral-large-2407",
    "qwen2.5-72b-instruct",
    "gemma-2-27b-it",
    "glm-4-plus",
)
SAMPLING_WEIGHTS = {model: 1 for model in _UNIFORM_WEIGHT_MODELS}
SAMPLING_WEIGHTS["p2l"] = 1000
179197

180198
# target model sampling weights will be boosted.
181199
BATTLE_TARGETS = {}

fastchat/serve/gradio_web_server.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import random
1212
import time
1313
import uuid
14-
from typing import List
14+
from typing import List, Dict
1515

1616
import gradio as gr
1717
import requests
@@ -119,6 +119,8 @@ def __init__(self, model_name, is_vision=False):
119119
self.model_name = model_name
120120
self.oai_thread_id = None
121121
self.is_vision = is_vision
122+
self.ans_models = []
123+
self.router_outputs = []
122124

123125
# NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes.
124126
self.has_csam_image = False
@@ -128,6 +130,15 @@ def __init__(self, model_name, is_vision=False):
128130
self.regen_support = False
129131
self.init_system_prompt(self.conv, is_vision)
130132

133+
def update_ans_models(self, ans: str) -> None:
    """Record the answering model name reported for the latest turn."""
    self.ans_models.append(ans)
136+
137+
def update_router_outputs(self, outputs: Dict[str, float]) -> None:
    """Record the router's per-model outputs reported for the latest turn."""
    self.router_outputs.append(outputs)
140+
141+
131142
def init_system_prompt(self, conv, is_vision):
132143
system_prompt = conv.get_system_message(is_vision)
133144
if len(system_prompt) == 0:
@@ -154,6 +165,20 @@ def dict(self):
154165
}
155166
)
156167

168+
if self.ans_models:
169+
base.update(
170+
{
171+
"ans_models": self.ans_models,
172+
}
173+
)
174+
175+
if self.router_outputs:
176+
base.update(
177+
{
178+
"router_outputs": self.router_outputs,
179+
}
180+
)
181+
157182
if self.is_vision:
158183
base.update({"has_csam_image": self.has_csam_image})
159184
return base
@@ -420,7 +445,7 @@ def is_limit_reached(model_name, ip):
420445

421446

422447
def bot_response(
423-
state,
448+
state: State,
424449
temperature,
425450
top_p,
426451
max_new_tokens,
@@ -532,6 +557,23 @@ def bot_response(
532557
try:
533558
data = {"text": ""}
534559
for i, data in enumerate(stream_iter):
560+
561+
# Change for P2L:
562+
if i == 0:
563+
564+
if "ans_model" in data:
565+
566+
ans_model = data.get("ans_model")
567+
568+
state.update_ans_models(ans_model)
569+
570+
if "router_outputs" in data:
571+
572+
router_outputs = data.get("router_outputs")
573+
574+
state.update_router_outputs(router_outputs)
575+
576+
535577
if data["error_code"] == 0:
536578
output = data["text"].strip()
537579
conv.update_last_message(output + "▌")

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ dependencies = [
1919
]
2020

2121
[project.optional-dependencies]
22-
model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"]
23-
webui = ["gradio>=4.10"]
22+
model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf", "openai"]
23+
webui = ["gradio>=4.10", "plotly", "scipy"]
2424
train = ["einops", "flash-attn>=2.0", "wandb"]
2525
llm_judge = ["openai<1", "anthropic>=0.3", "ray"]
2626
dev = ["black==23.3.0", "pylint==2.8.2"]

0 commit comments

Comments
 (0)