Skip to content

Commit 3555d01

Browse files
committed
Merge remote-tracking branch 'fastchat/operation-202407' into moderation-log
2 parents 571f39e + 1ccbe8b commit 3555d01

File tree

7 files changed

+356
-149
lines changed

7 files changed

+356
-149
lines changed

fastchat/serve/gradio_block_arena_vision.py

Lines changed: 94 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import json
1111
import os
1212
import time
13+
from typing import List, Union
1314

1415
import gradio as gr
1516
from gradio.data_classes import FileData
@@ -27,6 +28,7 @@
2728
from fastchat.model.model_adapter import (
2829
get_conversation_template,
2930
)
31+
from fastchat.serve.gradio_global_state import Context
3032
from fastchat.serve.gradio_web_server import (
3133
get_model_description_md,
3234
acknowledgment_md,
@@ -153,14 +155,14 @@ def clear_history(request: gr.Request):
153155
ip = get_ip(request)
154156
logger.info(f"clear_history. ip: {ip}")
155157
state = None
156-
return (state, [], enable_multimodal_clear_input) + (disable_btn,) * 5
158+
return (state, [], enable_multimodal_clear_input, invisible_text, invisible_btn) + (disable_btn,) * 5
157159

158160

159161
def clear_history_example(request: gr.Request):
160162
ip = get_ip(request)
161163
logger.info(f"clear_history_example. ip: {ip}")
162164
state = None
163-
return (state, [], enable_multimodal_keep_input) + (disable_btn,) * 5
165+
return (state, [], enable_multimodal_keep_input, invisible_text, invisible_btn) + (disable_btn,) * 5
164166

165167

166168
# TODO(Chris): At some point, we would like this to be a live-reporting feature.
@@ -199,11 +201,16 @@ def add_text(state, model_selector, chat_input, request: gr.Request):
199201
logger.info(f"add_text. ip: {ip}. len: {len(text)}")
200202

201203
if state is None:
202-
state = State(model_selector, is_vision=True)
204+
if len(images) == 0:
205+
state = State(model_selector, is_vision=False)
206+
else:
207+
state = State(model_selector, is_vision=True)
203208

204209
if len(text) <= 0:
205210
state.skip_next = True
206-
return (state, state.to_gradio_chatbot(), None) + (no_change_btn,) * 5
211+
return (state, state.to_gradio_chatbot(), None, "", no_change_btn) + (
212+
no_change_btn,
213+
) * 5
207214

208215
all_conv_text = state.conv.get_prompt()
209216
all_conv_text = all_conv_text[-2000:] + "\nuser: " + text
@@ -239,19 +246,29 @@ def add_text(state, model_selector, chat_input, request: gr.Request):
239246
if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
240247
logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
241248
state.skip_next = True
242-
return (state, state.to_gradio_chatbot(), {"text": CONVERSATION_LIMIT_MSG}) + (
249+
return (
250+
state,
251+
state.to_gradio_chatbot(),
252+
{"text": CONVERSATION_LIMIT_MSG},
253+
"",
243254
no_change_btn,
244-
) * 5
255+
) + (no_change_btn,) * 5
245256

246257
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
247258
text = _prepare_text_with_image(state, text, images)
248259
state.conv.append_message(state.conv.roles[0], text)
249260
state.conv.append_message(state.conv.roles[1], None)
250-
return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5
261+
return (
262+
state,
263+
state.to_gradio_chatbot(),
264+
disable_multimodal,
265+
visible_text,
266+
enable_btn,
267+
) + (disable_btn,) * 5
251268

252269

253270
def build_single_vision_language_model_ui(
254-
models, add_promotion_links=False, random_questions=None
271+
context: Context, add_promotion_links=False, random_questions=None
255272
):
256273
promotion = (
257274
f"""
@@ -273,33 +290,29 @@ def build_single_vision_language_model_ui(
273290

274291
state = gr.State()
275292
gr.Markdown(notice_markdown, elem_id="notice_markdown")
293+
text_and_vision_models = list(set(context.text_models + context.vision_models))
294+
context_state = gr.State(context)
276295

277296
with gr.Group():
278297
with gr.Row(elem_id="model_selector_row"):
279298
model_selector = gr.Dropdown(
280-
choices=models,
281-
value=models[0] if len(models) > 0 else "",
299+
choices=text_and_vision_models,
300+
value=text_and_vision_models[0]
301+
if len(text_and_vision_models) > 0
302+
else "",
282303
interactive=True,
283304
show_label=False,
284305
container=False,
285306
)
286307

287308
with gr.Accordion(
288-
f"🔍 Expand to see the descriptions of {len(models)} models", open=False
309+
f"🔍 Expand to see the descriptions of {len(text_and_vision_models)} models",
310+
open=False,
289311
):
290-
model_description_md = get_model_description_md(models)
312+
model_description_md = get_model_description_md(text_and_vision_models)
291313
gr.Markdown(model_description_md, elem_id="model_description_markdown")
292314

293315
with gr.Row():
294-
textbox = gr.MultimodalTextbox(
295-
file_types=["image"],
296-
show_label=False,
297-
placeholder="Enter your prompt or add image here",
298-
container=True,
299-
render=False,
300-
elem_id="input_box",
301-
)
302-
303316
with gr.Column(scale=2, visible=False) as image_column:
304317
imagebox = gr.Image(
305318
type="pil",
@@ -312,9 +325,24 @@ def build_single_vision_language_model_ui(
312325
)
313326

314327
with gr.Row():
315-
textbox.render()
316-
# with gr.Column(scale=1, min_width=50):
317-
# send_btn = gr.Button(value="Send", variant="primary")
328+
textbox = gr.Textbox(
329+
show_label=False,
330+
placeholder="👉 Enter your prompt and press ENTER",
331+
elem_id="input_box",
332+
visible=False,
333+
)
334+
335+
send_btn = gr.Button(
336+
value="Send", variant="primary", scale=0, visible=False, interactive=False
337+
)
338+
339+
multimodal_textbox = gr.MultimodalTextbox(
340+
file_types=["image"],
341+
show_label=False,
342+
placeholder="Enter your prompt or add image here",
343+
container=True,
344+
elem_id="input_box",
345+
)
318346

319347
with gr.Row(elem_id="buttons"):
320348
if random_questions:
@@ -328,22 +356,6 @@ def build_single_vision_language_model_ui(
328356
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
329357
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
330358

331-
cur_dir = os.path.dirname(os.path.abspath(__file__))
332-
333-
examples = gr.Examples(
334-
examples=[
335-
{
336-
"text": "How can I prepare a delicious meal using these ingredients?",
337-
"files": [f"{cur_dir}/example_images/fridge.jpg"],
338-
},
339-
{
340-
"text": "What might the woman on the right be thinking about?",
341-
"files": [f"{cur_dir}/example_images/distracted.jpg"],
342-
},
343-
],
344-
inputs=[textbox],
345-
)
346-
347359
with gr.Accordion("Parameters", open=False) as parameter_row:
348360
temperature = gr.Slider(
349361
minimum=0.0,
@@ -395,23 +407,50 @@ def build_single_vision_language_model_ui(
395407
[state, temperature, top_p, max_output_tokens],
396408
[state, chatbot] + btn_list,
397409
)
398-
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
410+
clear_btn.click(
411+
clear_history,
412+
None,
413+
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
414+
)
399415

400416
model_selector.change(
401-
clear_history, None, [state, chatbot, textbox] + btn_list
402-
).then(set_visible_image, [textbox], [image_column])
403-
examples.dataset.click(
404-
clear_history_example, None, [state, chatbot, textbox] + btn_list
417+
clear_history,
418+
None,
419+
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
420+
).then(set_visible_image, [multimodal_textbox], [image_column])
421+
422+
multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
423+
set_visible_image, [multimodal_textbox], [image_column]
424+
).then(
425+
clear_history_example,
426+
None,
427+
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
405428
)
406429

407-
textbox.input(add_image, [textbox], [imagebox]).then(
408-
set_visible_image, [textbox], [image_column]
409-
).then(clear_history_example, None, [state, chatbot, textbox] + btn_list)
430+
multimodal_textbox.submit(
431+
add_text,
432+
[state, model_selector, multimodal_textbox, context_state],
433+
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
434+
).then(set_invisible_image, [], [image_column]).then(
435+
bot_response,
436+
[state, temperature, top_p, max_output_tokens],
437+
[state, chatbot] + btn_list,
438+
)
410439

411440
textbox.submit(
412441
add_text,
413-
[state, model_selector, textbox],
414-
[state, chatbot, textbox] + btn_list,
442+
[state, model_selector, textbox, context_state],
443+
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
444+
).then(set_invisible_image, [], [image_column]).then(
445+
bot_response,
446+
[state, temperature, top_p, max_output_tokens],
447+
[state, chatbot] + btn_list,
448+
)
449+
450+
send_btn.click(
451+
add_text,
452+
[state, model_selector, textbox, context_state],
453+
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
415454
).then(set_invisible_image, [], [image_column]).then(
416455
bot_response,
417456
[state, temperature, top_p, max_output_tokens],
@@ -422,9 +461,11 @@ def build_single_vision_language_model_ui(
422461
random_btn.click(
423462
get_vqa_sample, # First, get the VQA sample
424463
[], # Pass the path to the VQA samples
425-
[textbox, imagebox], # Outputs are textbox and imagebox
426-
).then(set_visible_image, [textbox], [image_column]).then(
427-
clear_history_example, None, [state, chatbot, textbox] + btn_list
464+
[multimodal_textbox, imagebox], # Outputs are textbox and imagebox
465+
).then(set_visible_image, [multimodal_textbox], [image_column]).then(
466+
clear_history_example,
467+
None,
468+
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
428469
)
429470

430471
return [state, model_selector]

fastchat/serve/gradio_block_arena_vision_anony.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import gradio as gr
1010
import numpy as np
11+
from typing import Union
1112

1213
from fastchat.constants import (
1314
TEXT_MODERATION_MSG,
@@ -48,7 +49,6 @@
4849
regenerate,
4950
clear_history,
5051
share_click,
51-
add_text,
5252
bot_response_multi,
5353
set_global_vars_anony,
5454
load_demo_side_by_side_anony,
@@ -75,6 +75,7 @@
7575
BaseContentModerator,
7676
AzureAndOpenAIContentModerator,
7777
)
78+
from fastchat.serve.gradio_global_state import Context
7879
from fastchat.serve.remote_logger import get_remote_logger
7980
from fastchat.utils import (
8081
build_logger,
@@ -121,16 +122,12 @@ def get_vqa_sample():
121122
return (res, path)
122123

123124

124-
def load_demo_side_by_side_vision_anony(all_text_models, all_vl_models, url_params):
125-
global text_models, vl_models
126-
text_models = all_text_models
127-
vl_models = all_vl_models
128-
129-
states = (None,) * num_sides
130-
selector_updates = (
125+
def load_demo_side_by_side_vision_anony():
126+
states = [None] * num_sides
127+
selector_updates = [
131128
gr.Markdown(visible=True),
132129
gr.Markdown(visible=True),
133-
)
130+
]
134131

135132
return states + selector_updates
136133

@@ -256,7 +253,13 @@ def clear_history(request: gr.Request):
256253

257254

258255
def add_text(
259-
state0, state1, model_selector0, model_selector1, chat_input, request: gr.Request
256+
state0,
257+
state1,
258+
model_selector0,
259+
model_selector1,
260+
chat_input: Union[str, dict],
261+
context: Context,
262+
request: gr.Request,
260263
):
261264
if isinstance(chat_input, dict):
262265
text, images = chat_input["text"], chat_input["files"]
@@ -275,7 +278,7 @@ def add_text(
275278

276279
if len(images) > 0:
277280
model_left, model_right = get_battle_pair(
278-
vl_models,
281+
context.all_vision_models,
279282
VISION_BATTLE_TARGETS,
280283
VISION_OUTAGE_MODELS,
281284
VISION_SAMPLING_WEIGHTS,
@@ -287,7 +290,7 @@ def add_text(
287290
]
288291
else:
289292
model_left, model_right = get_battle_pair(
290-
text_models,
293+
context.all_text_models,
291294
BATTLE_TARGETS,
292295
OUTAGE_MODELS,
293296
SAMPLING_WEIGHTS,
@@ -408,8 +411,8 @@ def add_text(
408411
)
409412

410413

411-
def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=None):
412-
notice_markdown = f"""
414+
def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
415+
notice_markdown = """
413416
# ⚔️ LMSYS Chatbot Arena (Multimodal): Benchmarking LLMs and VLMs in the Wild
414417
[Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2403.04132) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | [Kaggle Competition](https://www.kaggle.com/competitions/lmsys-chatbot-arena)
415418
@@ -432,7 +435,9 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
432435
chatbots = [None] * num_sides
433436
show_vote_buttons = gr.State(True)
434437

438+
context_state = gr.State(context)
435439
gr.Markdown(notice_markdown, elem_id="notice_markdown")
440+
text_and_vision_models = list(set(context.text_models + context.vision_models))
436441

437442
with gr.Row():
438443
with gr.Column(scale=2, visible=False) as image_column:
@@ -445,11 +450,11 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
445450
with gr.Column(scale=5):
446451
with gr.Group(elem_id="share-region-anony"):
447452
with gr.Accordion(
448-
f"🔍 Expand to see the descriptions of {len(text_models) + len(vl_models)} models",
453+
f"🔍 Expand to see the descriptions of {len(text_and_vision_models)} models",
449454
open=False,
450455
):
451456
model_description_md = get_model_description_md(
452-
text_models + vl_models
457+
text_and_vision_models
453458
)
454459
gr.Markdown(
455460
model_description_md, elem_id="model_description_markdown"
@@ -630,7 +635,7 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
630635

631636
multimodal_textbox.submit(
632637
add_text,
633-
states + model_selectors + [multimodal_textbox],
638+
states + model_selectors + [multimodal_textbox, context_state],
634639
states
635640
+ chatbots
636641
+ [multimodal_textbox, textbox, send_btn]
@@ -650,7 +655,7 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
650655

651656
textbox.submit(
652657
add_text,
653-
states + model_selectors + [textbox],
658+
states + model_selectors + [textbox, context_state],
654659
states
655660
+ chatbots
656661
+ [multimodal_textbox, textbox, send_btn]
@@ -670,7 +675,7 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
670675

671676
send_btn.click(
672677
add_text,
673-
states + model_selectors + [textbox],
678+
states + model_selectors + [textbox, context_state],
674679
states
675680
+ chatbots
676681
+ [multimodal_textbox, textbox, send_btn]

0 commit comments

Comments
 (0)