Commit a2200e4

Merge with unified vision arena
1 parent fe45c6f commit a2200e4

6 files changed: 40 additions, 33 deletions

fastchat/serve/gradio_block_arena_vision.py
Lines changed: 19 additions, 4 deletions

@@ -199,8 +199,19 @@ def convert_images_to_conversation_format(images):
     return conv_images


-def add_text(state, model_selector, chat_input, request: gr.Request):
-    text, images = chat_input["text"], chat_input["files"]
+def add_text(state, model_selector, chat_input, context: Context, request: gr.Request):
+    if isinstance(chat_input, dict):
+        text, images = chat_input["text"], chat_input["files"]
+    else:
+        text, images = chat_input, []
+
+    if (
+        len(images) > 0
+        and model_selector in context.text_models
+        and model_selector not in context.vision_models
+    ):
+        gr.Warning(f"{model_selector} is a text-only model. Image is ignored.")
+        images = []
     ip = get_ip(request)
     logger.info(f"add_text. ip: {ip}. len: {len(text)}")

@@ -222,11 +233,13 @@ def add_text(state, model_selector, chat_input, request: gr.Request):
     images = convert_images_to_conversation_format(images)

     # Use the first state to get the moderation response because this is based on user input so it is independent of the model
+    moderation_image_input = images[0] if len(images) > 0 else None
     moderation_type_to_response_map = (
         state.content_moderator.image_and_text_moderation_filter(
-            images[0], text, [state.model_name], do_moderation=False
+            moderation_image_input, text, [state.model_name], do_moderation=False
         )
     )
+
     text_flagged, nsfw_flag, csam_flag = (
         moderation_type_to_response_map["text_moderation"]["flagged"],
         moderation_type_to_response_map["nsfw_moderation"]["flagged"],

@@ -245,7 +258,9 @@ def add_text(state, model_selector, chat_input, request: gr.Request):
         state.conv.append_message(state.conv.roles[0], post_processed_text)
         state.skip_next = True
         gr.Warning(MODERATION_MSG)
-        return (state, gradio_chatbot_before_user_input, None) + (no_change_btn,) * 5
+        return (state, gradio_chatbot_before_user_input, None, "", no_change_btn) + (
+            no_change_btn,
+        ) * 5

     if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
         logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
fastchat/serve/gradio_block_arena_vision_anony.py
Lines changed: 6 additions, 12 deletions

@@ -322,18 +322,12 @@ def add_text(
     images = convert_images_to_conversation_format(images)

     # Use the first state to get the moderation response because this is based on user input so it is independent of the model
-    if len(images) > 0:
-        moderation_type_to_response_map = states[
-            0
-        ].content_moderator.image_and_text_moderation_filter(
-            images[0], text, model_list, do_moderation=True
-        )
-    else:
-        moderation_type_to_response_map = states[
-            0
-        ].content_moderator.image_and_text_moderation_filter(
-            None, text, model_list, do_moderation=True
-        )
+    moderation_image_input = images[0] if len(images) > 0 else None
+    moderation_type_to_response_map = states[
+        0
+    ].content_moderator.image_and_text_moderation_filter(
+        moderation_image_input, text, model_list, do_moderation=True
+    )
     text_flagged, nsfw_flag, csam_flag = (
         moderation_type_to_response_map["text_moderation"]["flagged"],
         moderation_type_to_response_map["nsfw_moderation"]["flagged"],
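The change above is a pure de-duplication: both branches called image_and_text_moderation_filter with identical arguments except for the image, so the image argument is now computed once with a conditional expression. A toy sketch of the pattern, with a hypothetical moderate() standing in for the real filter:

def moderate(image, text, model_list, do_moderation):
    # hypothetical stand-in for image_and_text_moderation_filter;
    # returns the same kind of flag map the callers unpack
    return {
        "text_moderation": {"flagged": False},
        "nsfw_moderation": {"flagged": image is not None and do_moderation},
    }

def check(images, text, model_list):
    # before: the call site was duplicated behind `if len(images) > 0: ... else: ...`
    # after: pick the first image (or None) once and call once
    moderation_image_input = images[0] if len(images) > 0 else None
    return moderate(moderation_image_input, text, model_list, do_moderation=True)

print(check([], "hello", ["model-a"]))
print(check(["img.png"], "hello", ["model-a"]))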

fastchat/serve/gradio_block_arena_vision_named.py
Lines changed: 8 additions, 14 deletions

@@ -36,7 +36,6 @@
     convert_images_to_conversation_format,
     enable_multimodal_keep_input,
     enable_multimodal_clear_input,
-    enable_multimodal,
     disable_multimodal,
     invisible_text,
     invisible_btn,

@@ -254,18 +253,13 @@ def add_text(
     images = convert_images_to_conversation_format(images)

     # Use the first state to get the moderation response because this is based on user input so it is independent of the model
-    if len(images) > 0:
-        moderation_type_to_response_map = states[
-            0
-        ].content_moderator.image_and_text_moderation_filter(
-            images[0], text, model_list, do_moderation=False
-        )
-    else:
-        moderation_type_to_response_map = states[
-            0
-        ].content_moderator.image_and_text_moderation_filter(
-            None, text, model_list, do_moderation=False
-        )
+    moderation_image_input = images[0] if len(images) > 0 else None
+    moderation_type_to_response_map = states[
+        0
+    ].content_moderator.image_and_text_moderation_filter(
+        moderation_image_input, text, model_list, do_moderation=False
+    )
+
     text_flagged, nsfw_flag, csam_flag = (
         moderation_type_to_response_map["text_moderation"]["flagged"],
         moderation_type_to_response_map["nsfw_moderation"]["flagged"],

@@ -309,7 +303,7 @@ def add_text(
     return (
         states
         + gradio_chatbot_list
-        + [None]
+        + [None, "", no_change_btn]
         + [
             no_change_btn,
         ]
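Besides the same moderation de-duplication and the removal of the unused enable_multimodal import, the early return grows from [None] to [None, "", no_change_btn]. Gradio only requires that the number of returned values match the handler's outputs list, so every return path has to grow by the same number of slots; a toy illustration (component names and counts here are assumptions, not the exact UI wiring):

# stand-ins for Gradio state values and component updates
no_change_btn = object()
states = [object(), object()]
gradio_chatbot_list = [object(), object()]

old_ret = states + gradio_chatbot_list + [None] + [no_change_btn]
new_ret = states + gradio_chatbot_list + [None, "", no_change_btn] + [no_change_btn]

# two extra slots, matching two extra output components added to the handler
print(len(new_ret) - len(old_ret))  # 2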

fastchat/serve/gradio_web_server.py
Lines changed: 1 addition, 1 deletion

@@ -477,7 +477,7 @@ def bot_response(
     if state.content_moderator.text_flagged or state.content_moderator.nsfw_flagged:
         start_tstamp = time.time()
         finish_tstamp = start_tstamp
-        conv.save_new_images(
+        state.conv.save_new_images(
             has_csam_images=state.has_csam_image,
             use_remote_storage=use_remote_storage,
         )
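This is a one-line bug fix: bot_response has no local named conv, so the flagged-content early path would raise NameError; the conversation lives on the per-session state. A reduced sketch of the corrected path, with stand-ins for the fastchat classes:

import time

class Conversation:
    # stand-in for fastchat's Conversation; the real method persists images
    def save_new_images(self, has_csam_images=False, use_remote_storage=False):
        print(f"saving images (csam={has_csam_images}, remote={use_remote_storage})")

class State:
    def __init__(self):
        self.conv = Conversation()
        self.has_csam_image = False

state = State()
start_tstamp = time.time()
finish_tstamp = start_tstamp
state.conv.save_new_images(   # was: conv.save_new_images(...) -> NameError
    has_csam_images=state.has_csam_image,
    use_remote_storage=False,
)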

fastchat/serve/gradio_web_server_multi.py
Lines changed: 4 additions, 1 deletion

@@ -192,7 +192,10 @@ def build_demo(context: Context, elo_results_file: str, leaderboard_table_file):
     if elo_results_file:
         with gr.Tab("🏆 Leaderboard", id=3):
             build_leaderboard_tab(
-                elo_results_file, leaderboard_table_file, show_plot=True
+                elo_results_file,
+                leaderboard_table_file,
+                arena_hard_leaderboard=None,
+                show_plot=True,
             )

     with gr.Tab("ℹ️ About Us", id=4):
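The call is reflowed to keyword style so the new arena_hard_leaderboard argument slots in explicitly; passing None skips the Arena-Hard-Auto tab per the monitor.py change below. A reduced sketch of the call shape, with a hypothetical trimmed-down signature and placeholder paths:

def build_leaderboard_tab(
    elo_results_file, leaderboard_table_file, arena_hard_leaderboard=None, show_plot=False
):
    # hypothetical reduction of monitor.build_leaderboard_tab: the real one
    # builds Gradio tabs and only adds Arena-Hard-Auto when the arg is not None
    if arena_hard_leaderboard is not None:
        print("building Arena-Hard-Auto tab")
    print(f"show_plot={show_plot}")

build_leaderboard_tab(
    "elo_results.pkl",        # placeholder path
    "leaderboard_table.csv",  # placeholder path
    arena_hard_leaderboard=None,
    show_plot=True,
)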

fastchat/serve/monitor/monitor.py
Lines changed: 2 additions, 1 deletion

@@ -864,6 +864,8 @@ def build_leaderboard_tab(
         vision=True,
         show_plot=show_plot,
     )
+
+    model_to_score = {}
     if arena_hard_leaderboard is not None:
         with gr.Tab("Arena-Hard-Auto", id=3):
             dataFrame = arena_hard_process(

@@ -883,7 +885,6 @@ def build_leaderboard_tab(
                     "avg_tokens": "Average Tokens",
                 }
             )
-            model_to_score = {}
             for i in range(len(dataFrame)):
                 model_to_score[dataFrame.loc[i, "Model"]] = dataFrame.loc[
                     i, "Win-rate"
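Hoisting model_to_score = {} out of the Arena-Hard branch makes the dict exist even when arena_hard_leaderboard is None, so later code that reads it no longer risks a NameError. A minimal illustration of the pitfall and the fix (the rows are made up):

arena_hard_leaderboard = None  # e.g. the web server above now passes None

model_to_score = {}  # after the fix: always defined
if arena_hard_leaderboard is not None:
    rows = [("model-a", 0.61), ("model-b", 0.49)]  # stand-in for the Arena-Hard dataframe
    for model, win_rate in rows:
        model_to_score[model] = win_rate

# before the fix, reading the dict here raised NameError whenever the branch was skipped
print(model_to_score)  # {}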
