1313from gradio import Accordion , Tab
1414from modelscope import GenerationConfig , snapshot_download
1515
16- from swift .llm import (DeployArguments , InferArguments , XRequestConfig , inference_client , inference_stream ,
17- limit_history_length , prepare_model_template )
16+ from swift .llm import (TEMPLATE_MAPPING , DeployArguments , InferArguments , XRequestConfig , inference_client ,
17+ inference_stream , limit_history_length , prepare_model_template )
1818from swift .ui .base import BaseUI
1919from swift .ui .llm_infer .model import Model
2020from swift .ui .llm_infer .runtime import Runtime
@@ -349,7 +349,7 @@ def prepare_checkpoint(cls, *args):
349349
@classmethod
def clear_session(cls):
    """Reset the chat UI to a blank state.

    Returns a 4-tuple of Gradio outputs: empty prompt text, an empty
    chatbot history, an image-widget update that clears its value and
    re-enables interaction, and an empty raw-history list.
    """
    cleared_image = gr.update(value=None, interactive=True)
    return '', [], cleared_image, []
353353
354354 @classmethod
355355 def change_interactive (cls ):
@@ -365,6 +365,14 @@ def _replace_tag_with_media(cls, history):
365365 total_history .append (h [:2 ])
366366 return total_history
367367
@classmethod
def agent_type(cls, response):
    """Heuristically classify which agent prompt format *response* follows.

    Args:
        response (str): The model-generated text to inspect.

    Returns:
        'react' if the text ends with an ``Observation:`` marker (the ReAct
        stop condition), 'toolbench' if it contains ``Action Input:`` but no
        ``Observation:``, and None when neither format is detected.
    """
    # Hoist the case-folded copy: the original recomputed response.lower()
    # up to three times per call.
    lowered = response.lower()
    if lowered.endswith('observation:'):
        return 'react'
    if 'observation:' not in lowered and 'action input:' in lowered:
        return 'toolbench'
    return None
375+
368376 @classmethod
369377 def send_message (cls , running_task , model_and_template , template_type , prompt : str , image , history , system ,
370378 max_new_tokens , temperature , top_k , top_p , repetition_penalty ):
@@ -393,20 +401,38 @@ def send_message(cls, running_task, model_and_template, template_type, prompt: s
393401 request_config .stop = ['Observation:' ]
394402 stream_resp_with_history = ''
395403 medias = [m for h in old_history for m in h [2 ]]
404+ media_infer_type = TEMPLATE_MAPPING [template ].get ('infer_media_type' , 'round' )
405+ image_interactive = media_infer_type != 'dialogue'
406+
407+ text_history = [h for h in old_history if h [0 ]]
408+ roles = []
409+ for i in range (len (text_history ) + 1 ):
410+ roles .append (['user' , 'assistant' ])
411+
412+ for i , h in enumerate (text_history ):
413+ agent_type = cls .agent_type (h [1 ])
414+ if i < len (text_history ) - 1 and agent_type == 'toolbench' :
415+ roles [i + 1 ][0 ] = 'tool'
416+ if i == len (text_history ) - 1 and agent_type in ('toolbench' , 'react' ):
417+ roles [i + 1 ][0 ] = 'tool'
418+
396419 if not template_type .endswith ('generation' ):
397420 stream_resp = inference_client (
398421 model_type ,
399422 prompt ,
400423 images = medias ,
401- history = [h [:2 ] for h in old_history if h [ 0 ] ],
424+ history = [h [:2 ] for h in text_history ],
402425 system = system ,
403426 port = args ['port' ],
404- request_config = request_config )
427+ request_config = request_config ,
428+ roles = roles ,
429+ )
405430 for chunk in stream_resp :
406431 stream_resp_with_history += chunk .choices [0 ].delta .content
407432 old_history [- 1 ][0 ] = prompt
408433 old_history [- 1 ][1 ] = stream_resp_with_history
409- yield '' , cls ._replace_tag_with_media (old_history ), None , old_history
434+ yield ('' , cls ._replace_tag_with_media (old_history ),
435+ gr .update (value = None , interactive = image_interactive ), old_history )
410436 else :
411437 request_config .max_tokens = max_new_tokens
412438 stream_resp = inference_client (
@@ -415,7 +441,8 @@ def send_message(cls, running_task, model_and_template, template_type, prompt: s
415441 stream_resp_with_history += chunk .choices [0 ].text
416442 old_history [- 1 ][0 ] = prompt
417443 old_history [- 1 ][1 ] = stream_resp_with_history
418- yield '' , cls ._replace_tag_with_media (old_history ), None , old_history
444+ yield ('' , cls ._replace_tag_with_media (old_history ),
445+ gr .update (value = None , interactive = image_interactive ), old_history )
419446
420447 @classmethod
421448 def generate_chat (cls , model_and_template , template_type , prompt : str , image , history , system , max_new_tokens ,
0 commit comments