Skip to content

Commit 1ccbe8b

Browse files
authored
Enable vision arena across all tabs (#3483)
1 parent 68023e1 commit 1ccbe8b

File tree

7 files changed

+385
-154
lines changed

7 files changed

+385
-154
lines changed

fastchat/serve/gradio_block_arena_vision.py

Lines changed: 124 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import json
1111
import os
1212
import time
13+
from typing import List, Union
1314

1415
import gradio as gr
1516
from gradio.data_classes import FileData
@@ -26,6 +27,7 @@
2627
from fastchat.model.model_adapter import (
2728
get_conversation_template,
2829
)
30+
from fastchat.serve.gradio_global_state import Context
2931
from fastchat.serve.gradio_web_server import (
3032
get_model_description_md,
3133
acknowledgment_md,
def clear_history(request: gr.Request):
    """Reset the conversation after the Clear button is pressed.

    Logs the caller's IP, drops the per-session state, and returns the
    Gradio output tuple: cleared state, empty chatbot history, the
    ``enable_multimodal`` / ``invisible_text`` / ``invisible_btn``
    sentinels (module-level gr.update-style values — defined elsewhere
    in this file) for the input widgets, then five disabled buttons.
    """
    ip = get_ip(request)
    logger.info(f"clear_history. ip: {ip}")
    state = None
    # Widget resets first, then one `disable_btn` per action button.
    widget_resets = (state, [], enable_multimodal, invisible_text, invisible_btn)
    return widget_resets + (disable_btn,) * 5
148152

149153

150154
def clear_history_example(request: gr.Request):
    """Reset the conversation when an example / sample input is chosen.

    Mirrors ``clear_history`` but logs under a distinct tag so the two
    reset paths can be told apart in the server logs. Returns the same
    output tuple shape: cleared state, empty chatbot, the multimodal /
    text / send-button sentinels, then five disabled buttons.
    """
    ip = get_ip(request)
    logger.info(f"clear_history_example. ip: {ip}")
    state = None
    # Same reset payload as clear_history; only the log message differs.
    widget_resets = (state, [], enable_multimodal, invisible_text, invisible_btn)
    return widget_resets + (disable_btn,) * 5
155161

156162

157163
# TODO(Chris): At some point, we would like this to be a live-reporting feature.
@@ -209,17 +215,40 @@ def moderate_input(state, text, all_conv_text, model_list, images, ip):
209215
return text, image_flagged, csam_flagged
210216

211217

212-
def add_text(state, model_selector, chat_input, request: gr.Request):
213-
text, images = chat_input["text"], chat_input["files"]
218+
def add_text(
219+
state,
220+
model_selector,
221+
chat_input: Union[str, dict],
222+
context: Context,
223+
request: gr.Request,
224+
):
225+
if isinstance(chat_input, dict):
226+
text, images = chat_input["text"], chat_input["files"]
227+
else:
228+
text, images = chat_input, []
229+
230+
if (
231+
len(images) > 0
232+
and model_selector in context.text_models
233+
and model_selector not in context.vision_models
234+
):
235+
gr.Warning(f"{model_selector} is a text-only model. Image is ignored.")
236+
images = []
237+
214238
ip = get_ip(request)
215239
logger.info(f"add_text. ip: {ip}. len: {len(text)}")
216240

217241
if state is None:
218-
state = State(model_selector, is_vision=True)
242+
if len(images) == 0:
243+
state = State(model_selector, is_vision=False)
244+
else:
245+
state = State(model_selector, is_vision=True)
219246

220247
if len(text) <= 0:
221248
state.skip_next = True
222-
return (state, state.to_gradio_chatbot(), None) + (no_change_btn,) * 5
249+
return (state, state.to_gradio_chatbot(), None, "", no_change_btn) + (
250+
no_change_btn,
251+
) * 5
223252

224253
all_conv_text = state.conv.get_prompt()
225254
all_conv_text = all_conv_text[-2000:] + "\nuser: " + text
@@ -233,26 +262,40 @@ def add_text(state, model_selector, chat_input, request: gr.Request):
233262
if image_flagged:
234263
logger.info(f"image flagged. ip: {ip}. text: {text}")
235264
state.skip_next = True
236-
return (state, state.to_gradio_chatbot(), {"text": IMAGE_MODERATION_MSG}) + (
265+
return (
266+
state,
267+
state.to_gradio_chatbot(),
268+
{"text": IMAGE_MODERATION_MSG},
269+
"",
237270
no_change_btn,
238-
) * 5
271+
) + (no_change_btn,) * 5
239272

240273
if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
241274
logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
242275
state.skip_next = True
243-
return (state, state.to_gradio_chatbot(), {"text": CONVERSATION_LIMIT_MSG}) + (
276+
return (
277+
state,
278+
state.to_gradio_chatbot(),
279+
{"text": CONVERSATION_LIMIT_MSG},
280+
"",
244281
no_change_btn,
245-
) * 5
282+
) + (no_change_btn,) * 5
246283

247284
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
248285
text = _prepare_text_with_image(state, text, images, csam_flag=csam_flag)
249286
state.conv.append_message(state.conv.roles[0], text)
250287
state.conv.append_message(state.conv.roles[1], None)
251-
return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5
288+
return (
289+
state,
290+
state.to_gradio_chatbot(),
291+
disable_multimodal,
292+
visible_text,
293+
enable_btn,
294+
) + (disable_btn,) * 5
252295

253296

254297
def build_single_vision_language_model_ui(
255-
models, add_promotion_links=False, random_questions=None
298+
context: Context, add_promotion_links=False, random_questions=None
256299
):
257300
promotion = (
258301
"""
@@ -272,33 +315,29 @@ def build_single_vision_language_model_ui(
272315

273316
state = gr.State()
274317
gr.Markdown(notice_markdown, elem_id="notice_markdown")
318+
text_and_vision_models = list(set(context.text_models + context.vision_models))
319+
context_state = gr.State(context)
275320

276321
with gr.Group():
277322
with gr.Row(elem_id="model_selector_row"):
278323
model_selector = gr.Dropdown(
279-
choices=models,
280-
value=models[0] if len(models) > 0 else "",
324+
choices=text_and_vision_models,
325+
value=text_and_vision_models[0]
326+
if len(text_and_vision_models) > 0
327+
else "",
281328
interactive=True,
282329
show_label=False,
283330
container=False,
284331
)
285332

286333
with gr.Accordion(
287-
f"🔍 Expand to see the descriptions of {len(models)} models", open=False
334+
f"🔍 Expand to see the descriptions of {len(text_and_vision_models)} models",
335+
open=False,
288336
):
289-
model_description_md = get_model_description_md(models)
337+
model_description_md = get_model_description_md(text_and_vision_models)
290338
gr.Markdown(model_description_md, elem_id="model_description_markdown")
291339

292340
with gr.Row():
293-
textbox = gr.MultimodalTextbox(
294-
file_types=["image"],
295-
show_label=False,
296-
placeholder="Enter your prompt or add image here",
297-
container=True,
298-
render=False,
299-
elem_id="input_box",
300-
)
301-
302341
with gr.Column(scale=2, visible=False) as image_column:
303342
imagebox = gr.Image(
304343
type="pil",
@@ -311,9 +350,24 @@ def build_single_vision_language_model_ui(
311350
)
312351

313352
with gr.Row():
314-
textbox.render()
315-
# with gr.Column(scale=1, min_width=50):
316-
# send_btn = gr.Button(value="Send", variant="primary")
353+
textbox = gr.Textbox(
354+
show_label=False,
355+
placeholder="👉 Enter your prompt and press ENTER",
356+
elem_id="input_box",
357+
visible=False,
358+
)
359+
360+
send_btn = gr.Button(
361+
value="Send", variant="primary", scale=0, visible=False, interactive=False
362+
)
363+
364+
multimodal_textbox = gr.MultimodalTextbox(
365+
file_types=["image"],
366+
show_label=False,
367+
placeholder="Enter your prompt or add image here",
368+
container=True,
369+
elem_id="input_box",
370+
)
317371

318372
with gr.Row(elem_id="buttons"):
319373
if random_questions:
@@ -327,22 +381,6 @@ def build_single_vision_language_model_ui(
327381
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
328382
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
329383

330-
cur_dir = os.path.dirname(os.path.abspath(__file__))
331-
332-
examples = gr.Examples(
333-
examples=[
334-
{
335-
"text": "How can I prepare a delicious meal using these ingredients?",
336-
"files": [f"{cur_dir}/example_images/fridge.jpg"],
337-
},
338-
{
339-
"text": "What might the woman on the right be thinking about?",
340-
"files": [f"{cur_dir}/example_images/distracted.jpg"],
341-
},
342-
],
343-
inputs=[textbox],
344-
)
345-
346384
with gr.Accordion("Parameters", open=False) as parameter_row:
347385
temperature = gr.Slider(
348386
minimum=0.0,
@@ -394,23 +432,50 @@ def build_single_vision_language_model_ui(
394432
[state, temperature, top_p, max_output_tokens],
395433
[state, chatbot] + btn_list,
396434
)
397-
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
435+
clear_btn.click(
436+
clear_history,
437+
None,
438+
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
439+
)
398440

399441
model_selector.change(
400-
clear_history, None, [state, chatbot, textbox] + btn_list
401-
).then(set_visible_image, [textbox], [image_column])
402-
examples.dataset.click(
403-
clear_history_example, None, [state, chatbot, textbox] + btn_list
442+
clear_history,
443+
None,
444+
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
445+
).then(set_visible_image, [multimodal_textbox], [image_column])
446+
447+
multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
448+
set_visible_image, [multimodal_textbox], [image_column]
449+
).then(
450+
clear_history_example,
451+
None,
452+
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
404453
)
405454

406-
textbox.input(add_image, [textbox], [imagebox]).then(
407-
set_visible_image, [textbox], [image_column]
408-
).then(clear_history_example, None, [state, chatbot, textbox] + btn_list)
455+
multimodal_textbox.submit(
456+
add_text,
457+
[state, model_selector, multimodal_textbox, context_state],
458+
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
459+
).then(set_invisible_image, [], [image_column]).then(
460+
bot_response,
461+
[state, temperature, top_p, max_output_tokens],
462+
[state, chatbot] + btn_list,
463+
)
409464

410465
textbox.submit(
411466
add_text,
412-
[state, model_selector, textbox],
413-
[state, chatbot, textbox] + btn_list,
467+
[state, model_selector, textbox, context_state],
468+
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
469+
).then(set_invisible_image, [], [image_column]).then(
470+
bot_response,
471+
[state, temperature, top_p, max_output_tokens],
472+
[state, chatbot] + btn_list,
473+
)
474+
475+
send_btn.click(
476+
add_text,
477+
[state, model_selector, textbox, context_state],
478+
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
414479
).then(set_invisible_image, [], [image_column]).then(
415480
bot_response,
416481
[state, temperature, top_p, max_output_tokens],
@@ -421,9 +486,11 @@ def build_single_vision_language_model_ui(
421486
random_btn.click(
422487
get_vqa_sample, # First, get the VQA sample
423488
[], # Pass the path to the VQA samples
424-
[textbox, imagebox], # Outputs are textbox and imagebox
425-
).then(set_visible_image, [textbox], [image_column]).then(
426-
clear_history_example, None, [state, chatbot, textbox] + btn_list
489+
[multimodal_textbox, imagebox], # Outputs are textbox and imagebox
490+
).then(set_visible_image, [multimodal_textbox], [image_column]).then(
491+
clear_history_example,
492+
None,
493+
[state, chatbot, multimodal_textbox, textbox, send_btn] + btn_list,
427494
)
428495

429496
return [state, model_selector]

0 commit comments

Comments
 (0)