1010import json
1111import os
1212import time
13+ from typing import List , Union
1314
1415import gradio as gr
1516from gradio .data_classes import FileData
2728from fastchat .model .model_adapter import (
2829 get_conversation_template ,
2930)
31+ from fastchat .serve .gradio_global_state import Context
3032from fastchat .serve .gradio_web_server import (
3133 get_model_description_md ,
3234 acknowledgment_md ,
@@ -153,14 +155,14 @@ def clear_history(request: gr.Request):
153155 ip = get_ip (request )
154156 logger .info (f"clear_history. ip: { ip } " )
155157 state = None
156- return (state , [], enable_multimodal_clear_input ) + (disable_btn ,) * 5
158+ return (state , [], enable_multimodal_clear_input , invisible_text , invisible_btn ) + (disable_btn ,) * 5
157159
158160
159161def clear_history_example (request : gr .Request ):
160162 ip = get_ip (request )
161163 logger .info (f"clear_history_example. ip: { ip } " )
162164 state = None
163- return (state , [], enable_multimodal_keep_input ) + (disable_btn ,) * 5
165+ return (state , [], enable_multimodal_keep_input , invisible_text , invisible_btn ) + (disable_btn ,) * 5
164166
165167
166168# TODO(Chris): At some point, we would like this to be a live-reporting feature.
@@ -199,11 +201,16 @@ def add_text(state, model_selector, chat_input, request: gr.Request):
199201 logger .info (f"add_text. ip: { ip } . len: { len (text )} " )
200202
201203 if state is None :
202- state = State (model_selector , is_vision = True )
204+ if len (images ) == 0 :
205+ state = State (model_selector , is_vision = False )
206+ else :
207+ state = State (model_selector , is_vision = True )
203208
204209 if len (text ) <= 0 :
205210 state .skip_next = True
206- return (state , state .to_gradio_chatbot (), None ) + (no_change_btn ,) * 5
211+ return (state , state .to_gradio_chatbot (), None , "" , no_change_btn ) + (
212+ no_change_btn ,
213+ ) * 5
207214
208215 all_conv_text = state .conv .get_prompt ()
209216 all_conv_text = all_conv_text [- 2000 :] + "\n user: " + text
@@ -239,19 +246,29 @@ def add_text(state, model_selector, chat_input, request: gr.Request):
239246 if (len (state .conv .messages ) - state .conv .offset ) // 2 >= CONVERSATION_TURN_LIMIT :
240247 logger .info (f"conversation turn limit. ip: { ip } . text: { text } " )
241248 state .skip_next = True
242- return (state , state .to_gradio_chatbot (), {"text" : CONVERSATION_LIMIT_MSG }) + (
249+ return (
250+ state ,
251+ state .to_gradio_chatbot (),
252+ {"text" : CONVERSATION_LIMIT_MSG },
253+ "" ,
243254 no_change_btn ,
244- ) * 5
255+ ) + ( no_change_btn ,) * 5
245256
246257 text = text [:INPUT_CHAR_LEN_LIMIT ] # Hard cut-off
247258 text = _prepare_text_with_image (state , text , images )
248259 state .conv .append_message (state .conv .roles [0 ], text )
249260 state .conv .append_message (state .conv .roles [1 ], None )
250- return (state , state .to_gradio_chatbot (), None ) + (disable_btn ,) * 5
261+ return (
262+ state ,
263+ state .to_gradio_chatbot (),
264+ disable_multimodal ,
265+ visible_text ,
266+ enable_btn ,
267+ ) + (disable_btn ,) * 5
251268
252269
253270def build_single_vision_language_model_ui (
254- models , add_promotion_links = False , random_questions = None
271+ context : Context , add_promotion_links = False , random_questions = None
255272):
256273 promotion = (
257274 f"""
@@ -273,33 +290,29 @@ def build_single_vision_language_model_ui(
273290
274291 state = gr .State ()
275292 gr .Markdown (notice_markdown , elem_id = "notice_markdown" )
293+ text_and_vision_models = list (set (context .text_models + context .vision_models ))
294+ context_state = gr .State (context )
276295
277296 with gr .Group ():
278297 with gr .Row (elem_id = "model_selector_row" ):
279298 model_selector = gr .Dropdown (
280- choices = models ,
281- value = models [0 ] if len (models ) > 0 else "" ,
299+ choices = text_and_vision_models ,
300+ value = text_and_vision_models [0 ]
301+ if len (text_and_vision_models ) > 0
302+ else "" ,
282303 interactive = True ,
283304 show_label = False ,
284305 container = False ,
285306 )
286307
287308 with gr .Accordion (
288- f"🔍 Expand to see the descriptions of { len (models )} models" , open = False
309+ f"🔍 Expand to see the descriptions of { len (text_and_vision_models )} models" ,
310+ open = False ,
289311 ):
290- model_description_md = get_model_description_md (models )
312+ model_description_md = get_model_description_md (text_and_vision_models )
291313 gr .Markdown (model_description_md , elem_id = "model_description_markdown" )
292314
293315 with gr .Row ():
294- textbox = gr .MultimodalTextbox (
295- file_types = ["image" ],
296- show_label = False ,
297- placeholder = "Enter your prompt or add image here" ,
298- container = True ,
299- render = False ,
300- elem_id = "input_box" ,
301- )
302-
303316 with gr .Column (scale = 2 , visible = False ) as image_column :
304317 imagebox = gr .Image (
305318 type = "pil" ,
@@ -312,9 +325,24 @@ def build_single_vision_language_model_ui(
312325 )
313326
314327 with gr .Row ():
315- textbox .render ()
316- # with gr.Column(scale=1, min_width=50):
317- # send_btn = gr.Button(value="Send", variant="primary")
328+ textbox = gr .Textbox (
329+ show_label = False ,
330+ placeholder = "👉 Enter your prompt and press ENTER" ,
331+ elem_id = "input_box" ,
332+ visible = False ,
333+ )
334+
335+ send_btn = gr .Button (
336+ value = "Send" , variant = "primary" , scale = 0 , visible = False , interactive = False
337+ )
338+
339+ multimodal_textbox = gr .MultimodalTextbox (
340+ file_types = ["image" ],
341+ show_label = False ,
342+ placeholder = "Enter your prompt or add image here" ,
343+ container = True ,
344+ elem_id = "input_box" ,
345+ )
318346
319347 with gr .Row (elem_id = "buttons" ):
320348 if random_questions :
@@ -328,22 +356,6 @@ def build_single_vision_language_model_ui(
328356 regenerate_btn = gr .Button (value = "🔄 Regenerate" , interactive = False )
329357 clear_btn = gr .Button (value = "🗑️ Clear" , interactive = False )
330358
331- cur_dir = os .path .dirname (os .path .abspath (__file__ ))
332-
333- examples = gr .Examples (
334- examples = [
335- {
336- "text" : "How can I prepare a delicious meal using these ingredients?" ,
337- "files" : [f"{ cur_dir } /example_images/fridge.jpg" ],
338- },
339- {
340- "text" : "What might the woman on the right be thinking about?" ,
341- "files" : [f"{ cur_dir } /example_images/distracted.jpg" ],
342- },
343- ],
344- inputs = [textbox ],
345- )
346-
347359 with gr .Accordion ("Parameters" , open = False ) as parameter_row :
348360 temperature = gr .Slider (
349361 minimum = 0.0 ,
@@ -395,23 +407,50 @@ def build_single_vision_language_model_ui(
395407 [state , temperature , top_p , max_output_tokens ],
396408 [state , chatbot ] + btn_list ,
397409 )
398- clear_btn .click (clear_history , None , [state , chatbot , textbox ] + btn_list )
410+ clear_btn .click (
411+ clear_history ,
412+ None ,
413+ [state , chatbot , multimodal_textbox , textbox , send_btn ] + btn_list ,
414+ )
399415
400416 model_selector .change (
401- clear_history , None , [state , chatbot , textbox ] + btn_list
402- ).then (set_visible_image , [textbox ], [image_column ])
403- examples .dataset .click (
404- clear_history_example , None , [state , chatbot , textbox ] + btn_list
417+ clear_history ,
418+ None ,
419+ [state , chatbot , multimodal_textbox , textbox , send_btn ] + btn_list ,
420+ ).then (set_visible_image , [multimodal_textbox ], [image_column ])
421+
422+ multimodal_textbox .input (add_image , [multimodal_textbox ], [imagebox ]).then (
423+ set_visible_image , [multimodal_textbox ], [image_column ]
424+ ).then (
425+ clear_history_example ,
426+ None ,
427+ [state , chatbot , multimodal_textbox , textbox , send_btn ] + btn_list ,
405428 )
406429
407- textbox .input (add_image , [textbox ], [imagebox ]).then (
408- set_visible_image , [textbox ], [image_column ]
409- ).then (clear_history_example , None , [state , chatbot , textbox ] + btn_list )
430+ multimodal_textbox .submit (
431+ add_text ,
432+ [state , model_selector , multimodal_textbox , context_state ],
433+ [state , chatbot , multimodal_textbox , textbox , send_btn ] + btn_list ,
434+ ).then (set_invisible_image , [], [image_column ]).then (
435+ bot_response ,
436+ [state , temperature , top_p , max_output_tokens ],
437+ [state , chatbot ] + btn_list ,
438+ )
410439
411440 textbox .submit (
412441 add_text ,
413- [state , model_selector , textbox ],
414- [state , chatbot , textbox ] + btn_list ,
442+ [state , model_selector , textbox , context_state ],
443+ [state , chatbot , multimodal_textbox , textbox , send_btn ] + btn_list ,
444+ ).then (set_invisible_image , [], [image_column ]).then (
445+ bot_response ,
446+ [state , temperature , top_p , max_output_tokens ],
447+ [state , chatbot ] + btn_list ,
448+ )
449+
450+ send_btn .click (
451+ add_text ,
452+ [state , model_selector , textbox , context_state ],
453+ [state , chatbot , multimodal_textbox , textbox , send_btn ] + btn_list ,
415454 ).then (set_invisible_image , [], [image_column ]).then (
416455 bot_response ,
417456 [state , temperature , top_p , max_output_tokens ],
@@ -422,9 +461,11 @@ def build_single_vision_language_model_ui(
422461 random_btn .click (
423462 get_vqa_sample , # First, get the VQA sample
424463 [], # Pass the path to the VQA samples
425- [textbox , imagebox ], # Outputs are textbox and imagebox
426- ).then (set_visible_image , [textbox ], [image_column ]).then (
427- clear_history_example , None , [state , chatbot , textbox ] + btn_list
464+ [multimodal_textbox , imagebox ], # Outputs are textbox and imagebox
465+ ).then (set_visible_image , [multimodal_textbox ], [image_column ]).then (
466+ clear_history_example ,
467+ None ,
468+ [state , chatbot , multimodal_textbox , textbox , send_btn ] + btn_list ,
428469 )
429470
430471 return [state , model_selector ]
0 commit comments