1010import json
1111import os
1212import time
13+ from typing import List , Union
1314
1415import gradio as gr
1516from gradio .data_classes import FileData
2627from fastchat .model .model_adapter import (
2728 get_conversation_template ,
2829)
30+ from fastchat .serve .gradio_global_state import Context
2931from fastchat .serve .gradio_web_server import (
3032 get_model_description_md ,
3133 acknowledgment_md ,
def clear_history(request: gr.Request):
    """Reset the conversation.

    Clears the State, empties the chatbot history, re-enables the
    multimodal input box, hides the plain textbox and send button, and
    disables all five action buttons.
    """
    logger.info(f"clear_history. ip: {get_ip(request)}")
    return (
        None,  # cleared conversation state
        [],  # empty chatbot history
        enable_multimodal,
        invisible_text,
        invisible_btn,
        disable_btn,
        disable_btn,
        disable_btn,
        disable_btn,
        disable_btn,
    )
148152
149153
def clear_history_example(request: gr.Request):
    """Reset the conversation after an example/random question is loaded.

    Emits the same outputs as clear_history: cleared state, empty chat,
    multimodal box enabled, plain textbox/send button hidden, and all
    five action buttons disabled.
    """
    logger.info(f"clear_history_example. ip: {get_ip(request)}")
    disabled_buttons = (disable_btn,) * 5
    return (
        None,
        [],
        enable_multimodal,
        invisible_text,
        invisible_btn,
    ) + disabled_buttons
155161
156162
157163# TODO(Chris): At some point, we would like this to be a live-reporting feature.
@@ -209,17 +215,40 @@ def moderate_input(state, text, all_conv_text, model_list, images, ip):
209215 return text , image_flagged , csam_flagged
210216
211217
212- def add_text (state , model_selector , chat_input , request : gr .Request ):
213- text , images = chat_input ["text" ], chat_input ["files" ]
218+ def add_text (
219+ state ,
220+ model_selector ,
221+ chat_input : Union [str , dict ],
222+ context : Context ,
223+ request : gr .Request ,
224+ ):
225+ if isinstance (chat_input , dict ):
226+ text , images = chat_input ["text" ], chat_input ["files" ]
227+ else :
228+ text , images = chat_input , []
229+
230+ if (
231+ len (images ) > 0
232+ and model_selector in context .text_models
233+ and model_selector not in context .vision_models
234+ ):
235+ gr .Warning (f"{ model_selector } is a text-only model. Image is ignored." )
236+ images = []
237+
214238 ip = get_ip (request )
215239 logger .info (f"add_text. ip: { ip } . len: { len (text )} " )
216240
217241 if state is None :
218- state = State (model_selector , is_vision = True )
242+ if len (images ) == 0 :
243+ state = State (model_selector , is_vision = False )
244+ else :
245+ state = State (model_selector , is_vision = True )
219246
220247 if len (text ) <= 0 :
221248 state .skip_next = True
222- return (state , state .to_gradio_chatbot (), None ) + (no_change_btn ,) * 5
249+ return (state , state .to_gradio_chatbot (), None , "" , no_change_btn ) + (
250+ no_change_btn ,
251+ ) * 5
223252
224253 all_conv_text = state .conv .get_prompt ()
225254 all_conv_text = all_conv_text [- 2000 :] + "\n user: " + text
@@ -233,26 +262,40 @@ def add_text(state, model_selector, chat_input, request: gr.Request):
233262 if image_flagged :
234263 logger .info (f"image flagged. ip: { ip } . text: { text } " )
235264 state .skip_next = True
236- return (state , state .to_gradio_chatbot (), {"text" : IMAGE_MODERATION_MSG }) + (
265+ return (
266+ state ,
267+ state .to_gradio_chatbot (),
268+ {"text" : IMAGE_MODERATION_MSG },
269+ "" ,
237270 no_change_btn ,
238- ) * 5
271+ ) + ( no_change_btn ,) * 5
239272
240273 if (len (state .conv .messages ) - state .conv .offset ) // 2 >= CONVERSATION_TURN_LIMIT :
241274 logger .info (f"conversation turn limit. ip: { ip } . text: { text } " )
242275 state .skip_next = True
243- return (state , state .to_gradio_chatbot (), {"text" : CONVERSATION_LIMIT_MSG }) + (
276+ return (
277+ state ,
278+ state .to_gradio_chatbot (),
279+ {"text" : CONVERSATION_LIMIT_MSG },
280+ "" ,
244281 no_change_btn ,
245- ) * 5
282+ ) + ( no_change_btn ,) * 5
246283
247284 text = text [:INPUT_CHAR_LEN_LIMIT ] # Hard cut-off
248285 text = _prepare_text_with_image (state , text , images , csam_flag = csam_flag )
249286 state .conv .append_message (state .conv .roles [0 ], text )
250287 state .conv .append_message (state .conv .roles [1 ], None )
251- return (state , state .to_gradio_chatbot (), None ) + (disable_btn ,) * 5
288+ return (
289+ state ,
290+ state .to_gradio_chatbot (),
291+ disable_multimodal ,
292+ visible_text ,
293+ enable_btn ,
294+ ) + (disable_btn ,) * 5
252295
253296
254297def build_single_vision_language_model_ui (
255- models , add_promotion_links = False , random_questions = None
298+ context : Context , add_promotion_links = False , random_questions = None
256299):
257300 promotion = (
258301 """
@@ -272,33 +315,29 @@ def build_single_vision_language_model_ui(
272315
273316 state = gr .State ()
274317 gr .Markdown (notice_markdown , elem_id = "notice_markdown" )
318+ text_and_vision_models = list (set (context .text_models + context .vision_models ))
319+ context_state = gr .State (context )
275320
276321 with gr .Group ():
277322 with gr .Row (elem_id = "model_selector_row" ):
278323 model_selector = gr .Dropdown (
279- choices = models ,
280- value = models [0 ] if len (models ) > 0 else "" ,
324+ choices = text_and_vision_models ,
325+ value = text_and_vision_models [0 ]
326+ if len (text_and_vision_models ) > 0
327+ else "" ,
281328 interactive = True ,
282329 show_label = False ,
283330 container = False ,
284331 )
285332
286333 with gr .Accordion (
287- f"🔍 Expand to see the descriptions of { len (models )} models" , open = False
334+ f"🔍 Expand to see the descriptions of { len (text_and_vision_models )} models" ,
335+ open = False ,
288336 ):
289- model_description_md = get_model_description_md (models )
337+ model_description_md = get_model_description_md (text_and_vision_models )
290338 gr .Markdown (model_description_md , elem_id = "model_description_markdown" )
291339
292340 with gr .Row ():
293- textbox = gr .MultimodalTextbox (
294- file_types = ["image" ],
295- show_label = False ,
296- placeholder = "Enter your prompt or add image here" ,
297- container = True ,
298- render = False ,
299- elem_id = "input_box" ,
300- )
301-
302341 with gr .Column (scale = 2 , visible = False ) as image_column :
303342 imagebox = gr .Image (
304343 type = "pil" ,
@@ -311,9 +350,24 @@ def build_single_vision_language_model_ui(
311350 )
312351
313352 with gr .Row ():
314- textbox .render ()
315- # with gr.Column(scale=1, min_width=50):
316- # send_btn = gr.Button(value="Send", variant="primary")
353+ textbox = gr .Textbox (
354+ show_label = False ,
355+ placeholder = "👉 Enter your prompt and press ENTER" ,
356+ elem_id = "input_box" ,
357+ visible = False ,
358+ )
359+
360+ send_btn = gr .Button (
361+ value = "Send" , variant = "primary" , scale = 0 , visible = False , interactive = False
362+ )
363+
364+ multimodal_textbox = gr .MultimodalTextbox (
365+ file_types = ["image" ],
366+ show_label = False ,
367+ placeholder = "Enter your prompt or add image here" ,
368+ container = True ,
369+ elem_id = "input_box" ,
370+ )
317371
318372 with gr .Row (elem_id = "buttons" ):
319373 if random_questions :
@@ -327,22 +381,6 @@ def build_single_vision_language_model_ui(
327381 regenerate_btn = gr .Button (value = "🔄 Regenerate" , interactive = False )
328382 clear_btn = gr .Button (value = "🗑️ Clear" , interactive = False )
329383
330- cur_dir = os .path .dirname (os .path .abspath (__file__ ))
331-
332- examples = gr .Examples (
333- examples = [
334- {
335- "text" : "How can I prepare a delicious meal using these ingredients?" ,
336- "files" : [f"{ cur_dir } /example_images/fridge.jpg" ],
337- },
338- {
339- "text" : "What might the woman on the right be thinking about?" ,
340- "files" : [f"{ cur_dir } /example_images/distracted.jpg" ],
341- },
342- ],
343- inputs = [textbox ],
344- )
345-
346384 with gr .Accordion ("Parameters" , open = False ) as parameter_row :
347385 temperature = gr .Slider (
348386 minimum = 0.0 ,
@@ -394,23 +432,50 @@ def build_single_vision_language_model_ui(
394432 [state , temperature , top_p , max_output_tokens ],
395433 [state , chatbot ] + btn_list ,
396434 )
397- clear_btn .click (clear_history , None , [state , chatbot , textbox ] + btn_list )
435+ clear_btn .click (
436+ clear_history ,
437+ None ,
438+ [state , chatbot , multimodal_textbox , textbox , send_btn ] + btn_list ,
439+ )
398440
399441 model_selector .change (
400- clear_history , None , [state , chatbot , textbox ] + btn_list
401- ).then (set_visible_image , [textbox ], [image_column ])
402- examples .dataset .click (
403- clear_history_example , None , [state , chatbot , textbox ] + btn_list
442+ clear_history ,
443+ None ,
444+ [state , chatbot , multimodal_textbox , textbox , send_btn ] + btn_list ,
445+ ).then (set_visible_image , [multimodal_textbox ], [image_column ])
446+
447+ multimodal_textbox .input (add_image , [multimodal_textbox ], [imagebox ]).then (
448+ set_visible_image , [multimodal_textbox ], [image_column ]
449+ ).then (
450+ clear_history_example ,
451+ None ,
452+ [state , chatbot , multimodal_textbox , textbox , send_btn ] + btn_list ,
404453 )
405454
406- textbox .input (add_image , [textbox ], [imagebox ]).then (
407- set_visible_image , [textbox ], [image_column ]
408- ).then (clear_history_example , None , [state , chatbot , textbox ] + btn_list )
455+ multimodal_textbox .submit (
456+ add_text ,
457+ [state , model_selector , multimodal_textbox , context_state ],
458+ [state , chatbot , multimodal_textbox , textbox , send_btn ] + btn_list ,
459+ ).then (set_invisible_image , [], [image_column ]).then (
460+ bot_response ,
461+ [state , temperature , top_p , max_output_tokens ],
462+ [state , chatbot ] + btn_list ,
463+ )
409464
410465 textbox .submit (
411466 add_text ,
412- [state , model_selector , textbox ],
413- [state , chatbot , textbox ] + btn_list ,
467+ [state , model_selector , textbox , context_state ],
468+ [state , chatbot , multimodal_textbox , textbox , send_btn ] + btn_list ,
469+ ).then (set_invisible_image , [], [image_column ]).then (
470+ bot_response ,
471+ [state , temperature , top_p , max_output_tokens ],
472+ [state , chatbot ] + btn_list ,
473+ )
474+
475+ send_btn .click (
476+ add_text ,
477+ [state , model_selector , textbox , context_state ],
478+ [state , chatbot , multimodal_textbox , textbox , send_btn ] + btn_list ,
414479 ).then (set_invisible_image , [], [image_column ]).then (
415480 bot_response ,
416481 [state , temperature , top_p , max_output_tokens ],
@@ -421,9 +486,11 @@ def build_single_vision_language_model_ui(
421486 random_btn .click (
422487 get_vqa_sample , # First, get the VQA sample
423488 [], # Pass the path to the VQA samples
424- [textbox , imagebox ], # Outputs are textbox and imagebox
425- ).then (set_visible_image , [textbox ], [image_column ]).then (
426- clear_history_example , None , [state , chatbot , textbox ] + btn_list
489+ [multimodal_textbox , imagebox ], # Outputs are textbox and imagebox
490+ ).then (set_visible_image , [multimodal_textbox ], [image_column ]).then (
491+ clear_history_example ,
492+ None ,
493+ [state , chatbot , multimodal_textbox , textbox , send_btn ] + btn_list ,
427494 )
428495
429496 return [state , model_selector ]
0 commit comments