22import re
33import sys
44import time
5- from copy import copy
65from datetime import datetime
76from functools import partial
87from typing import Type
1413from modelscope import GenerationConfig , snapshot_download
1514
1615from swift .llm import (TEMPLATE_MAPPING , DeployArguments , InferArguments , XRequestConfig , inference_client ,
17- inference_stream , limit_history_length , prepare_model_template )
16+ inference_stream , prepare_model_template )
1817from swift .ui .base import BaseUI
1918from swift .ui .llm_infer .model import Model
2019from swift .ui .llm_infer .runtime import Runtime
@@ -69,6 +68,16 @@ class LLMInfer(BaseUI):
6968 'en' : 'Chat bot'
7069 },
7170 },
71+ 'infer_model_type' : {
72+ 'label' : {
73+ 'zh' : 'Lora模块' ,
74+ 'en' : 'Lora module'
75+ },
76+ 'info' : {
77+ 'zh' : '发送给server端哪个LoRA,默认为`default-lora`' ,
78+ 'en' : 'Which LoRA to use on server, default value is `default-lora`'
79+ }
80+ },
7281 'prompt' : {
7382 'label' : {
7483 'zh' : '请输入:' ,
@@ -116,12 +125,14 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
116125 history = gr .State ([])
117126 Model .build_ui (base_tab )
118127 Runtime .build_ui (base_tab )
119- gr .Dropdown (
120- elem_id = 'gpu_id' ,
121- multiselect = True ,
122- choices = [str (i ) for i in range (gpu_count )] + ['cpu' ],
123- value = default_device ,
124- scale = 8 )
128+ with gr .Row ():
129+ gr .Dropdown (
130+ elem_id = 'gpu_id' ,
131+ multiselect = True ,
132+ choices = [str (i ) for i in range (gpu_count )] + ['cpu' ],
133+ value = default_device ,
134+ scale = 8 )
135+ infer_model_type = gr .Textbox (elem_id = 'infer_model_type' , scale = 4 )
125136 chatbot = gr .Chatbot (elem_id = 'chatbot' , elem_classes = 'control-height' )
126137 with gr .Row ():
127138 prompt = gr .Textbox (elem_id = 'prompt' , lines = 1 , interactive = True )
@@ -172,7 +183,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
172183 cls .send_message ,
173184 inputs = [
174185 cls .element ('running_tasks' ), model_and_template ,
175- cls .element ('template_type' ), prompt , image , history ,
186+ cls .element ('template_type' ), prompt , image , history , infer_model_type ,
176187 cls .element ('system' ),
177188 cls .element ('max_new_tokens' ),
178189 cls .element ('temperature' ),
@@ -217,7 +228,7 @@ def deploy(cls, *args):
217228 elif isinstance (value , str ) and re .fullmatch (cls .bool_regex , value ):
218229 value = True if value .lower () == 'true' else False
219230 kwargs [key ] = value if not isinstance (value , list ) else ' ' .join (value )
220- kwargs_is_list [key ] = isinstance (value , list )
231+ kwargs_is_list [key ] = isinstance (value , list ) or getattr ( cls . element ( key ), 'is_list' , False )
221232 else :
222233 other_kwargs [key ] = value
223234 if key == 'more_params' and value :
@@ -374,8 +385,8 @@ def agent_type(cls, response):
374385 return None
375386
376387 @classmethod
377- def send_message (cls , running_task , model_and_template , template_type , prompt : str , image , history , system ,
378- max_new_tokens , temperature , top_k , top_p , repetition_penalty ):
388+ def send_message (cls , running_task , model_and_template , template_type , prompt : str , image , history ,
389+ infer_model_type , system , max_new_tokens , temperature , top_k , top_p , repetition_penalty ):
379390 if not model_and_template :
380391 gr .Warning (cls .locale ('generate_alert' , cls .lang )['value' ])
381392 return '' , None , None , []
@@ -393,7 +404,7 @@ def send_message(cls, running_task, model_and_template, template_type, prompt: s
393404 _ , args = Runtime .parse_info_from_cmdline (running_task )
394405 model_type , template , sft_type = model_and_template
395406 if sft_type in ('lora' , 'longlora' ) and not args .get ('merge_lora' ):
396- model_type = 'default-lora'
407+ model_type = infer_model_type or 'default-lora'
397408 old_history , history = history or [], []
398409 request_config = XRequestConfig (
399410 temperature = temperature , top_k = top_k , top_p = top_p , repetition_penalty = repetition_penalty )
0 commit comments