Skip to content

Commit 68023e1

Browse files
BabyChouSr, infwinston, and simon-mo
authored
Use Reka Python SDK and add script for benchmarking and add send_btn (#3413)
Co-authored-by: Wei-Lin Chiang <[email protected]> Co-authored-by: Wei-Lin Chiang <[email protected]> Co-authored-by: simon-mo <[email protected]>
1 parent a71e3c6 commit 68023e1

File tree

6 files changed

+244
-59
lines changed

6 files changed

+244
-59
lines changed

fastchat/conversation.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -524,30 +524,55 @@ def to_anthropic_vision_api_messages(self):
524524

525525
def to_reka_api_messages(self):
526526
from fastchat.serve.vision.image import ImageFormat
527+
from reka import ChatMessage, TypedMediaContent, TypedText
527528

528529
ret = []
529530
for i, (_, msg) in enumerate(self.messages[self.offset :]):
530531
if i % 2 == 0:
531532
if type(msg) == tuple:
532533
text, images = msg
533534
for image in images:
534-
if image.image_format == ImageFormat.URL:
535-
ret.append(
536-
{"type": "human", "text": text, "media_url": image.url}
537-
)
538-
elif image.image_format == ImageFormat.BYTES:
535+
if image.image_format == ImageFormat.BYTES:
539536
ret.append(
540-
{
541-
"type": "human",
542-
"text": text,
543-
"media_url": f"data:image/{image.filetype};base64,{image.base64_str}",
544-
}
537+
ChatMessage(
538+
content=[
539+
TypedText(
540+
type="text",
541+
text=text,
542+
),
543+
TypedMediaContent(
544+
type="image_url",
545+
image_url=f"data:image/{image.filetype};base64,{image.base64_str}",
546+
),
547+
],
548+
role="user",
549+
)
545550
)
546551
else:
547-
ret.append({"type": "human", "text": msg})
552+
ret.append(
553+
ChatMessage(
554+
content=[
555+
TypedText(
556+
type="text",
557+
text=msg,
558+
)
559+
],
560+
role="user",
561+
)
562+
)
548563
else:
549564
if msg is not None:
550-
ret.append({"type": "model", "text": msg})
565+
ret.append(
566+
ChatMessage(
567+
content=[
568+
TypedText(
569+
type="text",
570+
text=msg,
571+
)
572+
],
573+
role="assistant",
574+
)
575+
)
551576

552577
return ret
553578

fastchat/serve/api_provider.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,8 +1076,13 @@ def reka_api_stream_iter(
10761076
api_key: Optional[str] = None, # default is env var CO_API_KEY
10771077
api_base: Optional[str] = None,
10781078
):
1079+
from reka.client import Reka
1080+
from reka import TypedText
1081+
10791082
api_key = api_key or os.environ["REKA_API_KEY"]
10801083

1084+
client = Reka(api_key=api_key)
1085+
10811086
use_search_engine = False
10821087
if "-online" in model_name:
10831088
model_name = model_name.replace("-online", "")
@@ -1094,34 +1099,27 @@ def reka_api_stream_iter(
10941099

10951100
# Make requests for logging
10961101
text_messages = []
1097-
for message in messages:
1098-
text_messages.append({"type": message["type"], "text": message["text"]})
1102+
for turn in messages:
1103+
for message in turn.content:
1104+
if isinstance(message, TypedText):
1105+
text_messages.append({"type": message.type, "text": message.text})
10991106
logged_request = dict(request)
11001107
logged_request["conversation_history"] = text_messages
11011108

11021109
logger.info(f"==== request ====\n{logged_request}")
11031110

1104-
response = requests.post(
1105-
api_base,
1106-
stream=True,
1107-
json=request,
1108-
headers={
1109-
"X-Api-Key": api_key,
1110-
},
1111+
response = client.chat.create_stream(
1112+
messages=messages,
1113+
max_tokens=max_new_tokens,
1114+
top_p=top_p,
1115+
model=model_name,
11111116
)
11121117

1113-
if response.status_code != 200:
1114-
error_message = response.text
1115-
logger.error(f"==== error from reka api: {error_message} ====")
1116-
yield {
1117-
"text": f"**API REQUEST ERROR** Reason: {error_message}",
1118-
"error_code": 1,
1119-
}
1120-
return
1121-
1122-
for line in response.iter_lines():
1123-
line = line.decode("utf8")
1124-
if not line.startswith("data: "):
1125-
continue
1126-
gen = json.loads(line[6:])
1127-
yield {"text": gen["text"], "error_code": 0}
1118+
for chunk in response:
1119+
try:
1120+
yield {"text": chunk.responses[0].chunk.content, "error_code": 0}
1121+
except:
1122+
yield {
1123+
"text": f"**API REQUEST ERROR** ",
1124+
"error_code": 1,
1125+
}

fastchat/serve/gradio_block_arena_vision_anony.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def clear_history_example(request: gr.Request):
134134
[None] * num_sides
135135
+ [None] * num_sides
136136
+ anony_names
137-
+ [enable_multimodal, invisible_text]
137+
+ [enable_multimodal, invisible_text, invisible_btn]
138138
+ [invisible_btn] * 4
139139
+ [disable_btn] * 2
140140
+ [enable_btn]
@@ -239,7 +239,7 @@ def clear_history(request: gr.Request):
239239
[None] * num_sides
240240
+ [None] * num_sides
241241
+ anony_names
242-
+ [enable_multimodal, invisible_text]
242+
+ [enable_multimodal, invisible_text, invisible_btn]
243243
+ [invisible_btn] * 4
244244
+ [disable_btn] * 2
245245
+ [enable_btn]
@@ -297,7 +297,7 @@ def add_text(
297297
return (
298298
states
299299
+ [x.to_gradio_chatbot() for x in states]
300-
+ [None, ""]
300+
+ [None, "", no_change_btn]
301301
+ [
302302
no_change_btn,
303303
]
@@ -321,7 +321,7 @@ def add_text(
321321
return (
322322
states
323323
+ [x.to_gradio_chatbot() for x in states]
324-
+ [{"text": CONVERSATION_LIMIT_MSG}, ""]
324+
+ [{"text": CONVERSATION_LIMIT_MSG}, "", no_change_btn]
325325
+ [
326326
no_change_btn,
327327
]
@@ -342,6 +342,7 @@ def add_text(
342342
+ " PLEASE CLICK 🎲 NEW ROUND TO START A NEW CONVERSATION."
343343
},
344344
"",
345+
no_change_btn,
345346
]
346347
+ [no_change_btn] * 7
347348
+ [""]
@@ -363,7 +364,7 @@ def add_text(
363364
return (
364365
states
365366
+ [x.to_gradio_chatbot() for x in states]
366-
+ [disable_multimodal, visible_text]
367+
+ [disable_multimodal, visible_text, enable_btn]
367368
+ [
368369
disable_btn,
369370
]
@@ -464,7 +465,9 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
464465
placeholder="Enter your prompt or add image here",
465466
elem_id="input_box",
466467
)
467-
# send_btn = gr.Button(value="Send", variant="primary", scale=0)
468+
send_btn = gr.Button(
469+
value="Send", variant="primary", scale=0, visible=False, interactive=False
470+
)
468471

469472
with gr.Row() as button_row:
470473
if random_questions:
@@ -548,7 +551,7 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
548551
states
549552
+ chatbots
550553
+ model_selectors
551-
+ [multimodal_textbox, textbox]
554+
+ [multimodal_textbox, textbox, send_btn]
552555
+ btn_list
553556
+ [random_btn]
554557
+ [slow_warning],
@@ -581,15 +584,19 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
581584
).then(
582585
clear_history_example,
583586
None,
584-
states + chatbots + model_selectors + [multimodal_textbox, textbox] + btn_list,
587+
states
588+
+ chatbots
589+
+ model_selectors
590+
+ [multimodal_textbox, textbox, send_btn]
591+
+ btn_list,
585592
)
586593

587594
multimodal_textbox.submit(
588595
add_text,
589596
states + model_selectors + [multimodal_textbox],
590597
states
591598
+ chatbots
592-
+ [multimodal_textbox, textbox]
599+
+ [multimodal_textbox, textbox, send_btn]
593600
+ btn_list
594601
+ [random_btn]
595602
+ [slow_warning],
@@ -608,7 +615,26 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
608615
states + model_selectors + [textbox],
609616
states
610617
+ chatbots
611-
+ [multimodal_textbox, textbox]
618+
+ [multimodal_textbox, textbox, send_btn]
619+
+ btn_list
620+
+ [random_btn]
621+
+ [slow_warning],
622+
).then(
623+
bot_response_multi,
624+
states + [temperature, top_p, max_output_tokens],
625+
states + chatbots + btn_list,
626+
).then(
627+
flash_buttons,
628+
[],
629+
btn_list,
630+
)
631+
632+
send_btn.click(
633+
add_text,
634+
states + model_selectors + [textbox],
635+
states
636+
+ chatbots
637+
+ [multimodal_textbox, textbox, send_btn]
612638
+ btn_list
613639
+ [random_btn]
614640
+ [slow_warning],
@@ -633,7 +659,7 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
633659
states
634660
+ chatbots
635661
+ model_selectors
636-
+ [multimodal_textbox, textbox]
662+
+ [multimodal_textbox, textbox, send_btn]
637663
+ btn_list
638664
+ [random_btn],
639665
)

fastchat/serve/monitor/monitor.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,7 @@ def build_arena_tab(
490490
)
491491
return
492492

493+
round_digit = None if vision else None
493494
arena_dfs = {}
494495
category_elo_results = {}
495496
last_updated_time = elo_results["full"]["last_updated_datetime"].split(" ")[0]
@@ -512,6 +513,7 @@ def update_leaderboard_and_plots(category):
512513
arena_df,
513514
model_table_df,
514515
arena_subset_df=arena_subset_df if category != "Overall" else None,
516+
round_digit=round_digit,
515517
)
516518
if category != "Overall":
517519
arena_values = update_leaderboard_df(arena_values)
@@ -665,9 +667,7 @@ def update_leaderboard_and_plots(category):
665667
elem_id="leaderboard_markdown",
666668
)
667669

668-
if not vision:
669-
# only live update the text tab
670-
leader_component_values[:] = [default_md, p1, p2, p3, p4]
670+
leader_component_values[:] = [default_md, p1, p2, p3, p4]
671671

672672
if show_plot:
673673
more_stats_md = gr.Markdown(
@@ -740,7 +740,7 @@ def build_full_leaderboard_tab(elo_results, model_table_df):
740740

741741

742742
def build_leaderboard_tab(
743-
elo_results_file, leaderboard_table_file, show_plot=False, mirror=False
743+
elo_results_file, leaderboard_table_file, vision=True, show_plot=False, mirror=False
744744
):
745745
if elo_results_file is None: # Do live update
746746
default_md = "Loading ..."
@@ -776,14 +776,15 @@ def build_leaderboard_tab(
776776
default_md,
777777
show_plot=show_plot,
778778
)
779-
with gr.Tab("📣 NEW: Arena (Vision)", id=1):
780-
build_arena_tab(
781-
elo_results_vision,
782-
model_table_df,
783-
default_md,
784-
vision=True,
785-
show_plot=show_plot,
786-
)
779+
if vision:
780+
with gr.Tab("📣 NEW: Arena (Vision)", id=1):
781+
build_arena_tab(
782+
elo_results_vision,
783+
model_table_df,
784+
default_md,
785+
vision=True,
786+
show_plot=show_plot,
787+
)
787788
with gr.Tab("Full Leaderboard", id=2):
788789
build_full_leaderboard_tab(elo_results_text, model_table_df)
789790

playground/__init__.py

Whitespace-only changes.

0 commit comments

Comments (0)