Skip to content

Commit 9aa5cad

Browse files
committed
format
1 parent 045be20 commit 9aa5cad

File tree

2 files changed

+6
-15
lines changed

2 files changed

+6
-15
lines changed

fastchat/serve/api_provider.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def p2l_api_stream_iter(
433433
api_key=None,
434434
):
435435
import openai
436-
436+
437437
client = openai.OpenAI(
438438
base_url=api_base,
439439
api_key=api_key or "-",
@@ -479,13 +479,12 @@ def p2l_api_stream_iter(
479479
}
480480

481481
if chunk_idx == 0:
482-
483482
if hasattr(chunk.choices[0].delta, "model"):
484483
data["ans_model"] = chunk.choices[0].delta.model
485-
484+
486485
if hasattr(chunk, "router_outputs"):
487486
data["router_outputs"] = chunk.router_outputs
488-
487+
489488
yield data
490489

491490

fastchat/serve/gradio_web_server.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -131,14 +131,11 @@ def __init__(self, model_name, is_vision=False):
131131
self.init_system_prompt(self.conv, is_vision)
132132

133133
def update_ans_models(self, ans: str) -> None:
134-
135134
self.ans_models.append(ans)
136135

137136
def update_router_outputs(self, outputs: Dict[str, float]) -> None:
138-
139137
self.router_outputs.append(outputs)
140138

141-
142139
def init_system_prompt(self, conv, is_vision):
143140
system_prompt = conv.get_system_message(is_vision)
144141
if len(system_prompt) == 0:
@@ -171,7 +168,7 @@ def dict(self):
171168
"ans_models": self.ans_models,
172169
}
173170
)
174-
171+
175172
if self.router_outputs:
176173
base.update(
177174
{
@@ -557,23 +554,18 @@ def bot_response(
557554
try:
558555
data = {"text": ""}
559556
for i, data in enumerate(stream_iter):
560-
561557
# Change for P2L:
562558
if i == 0:
563-
564559
if "ans_model" in data:
565-
566560
ans_model = data.get("ans_model")
567-
561+
568562
state.update_ans_models(ans_model)
569-
570-
if "router_outputs" in data:
571563

564+
if "router_outputs" in data:
572565
router_outputs = data.get("router_outputs")
573566

574567
state.update_router_outputs(router_outputs)
575568

576-
577569
if data["error_code"] == 0:
578570
output = data["text"].strip()
579571
conv.update_last_message(output + "▌")

0 commit comments

Comments
 (0)