Skip to content

Commit c8b1308

Browse files
Refactor UI and fix some bugs (#1300)
1 parent 7b57fb2 commit c8b1308

File tree

10 files changed

+240
-164
lines changed

10 files changed

+240
-164
lines changed

swift/llm/utils/client_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def _pre_inference_client(model_type: str,
169169
url = f'http://{host}:{port}/v1'
170170
url = url.rstrip('/')
171171
if is_chat_request:
172-
messages = history_to_messages(history, query, system)
172+
messages = history_to_messages(history, query, system, kwargs.get('roles'))
173173
if is_multimodal:
174174
messages = convert_to_base64(messages=messages)['messages']
175175
images = convert_to_base64(images=images)['images']

swift/llm/utils/utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -820,18 +820,22 @@ def compute_token_length(history_length: int) -> int:
820820

821821
def history_to_messages(history: Optional[History],
                        query: Optional[str] = None,
                        system: Optional[str] = None,
                        roles: Optional[List[List[str]]] = None) -> Messages:
    """Convert a (query, response) history into an OpenAI-style messages list.

    Args:
        history: Sequence of ``[query, response]`` pairs; ``None`` is treated
            as an empty history.
        query: The current user query to append as the final message, if any.
        system: Optional system prompt; emitted first when present.
        roles: Optional per-round ``[query_role, response_role]`` pairs. Must
            contain ``len(history) + 1`` entries — the extra entry supplies
            the role for ``query``. Defaults to ``['user', 'assistant']``
            for every round.

    Returns:
        A list of ``{'role': ..., 'content': ...}`` dicts.
    """
    if history is None:
        history = []
    messages = []
    if not roles:
        # One role pair per history round, plus one for the trailing query.
        roles = [['user', 'assistant']] * (len(history) + 1)
    assert len(roles) == len(history) + 1
    if system is not None:
        messages.append({'role': 'system', 'content': system})
    for role, h in zip(roles, history):
        assert isinstance(h, (list, tuple))
        messages.append({'role': role[0], 'content': h[0]})
        messages.append({'role': role[1], 'content': h[1]})
    if query is not None:
        # The last role pair's query-role applies to the pending query.
        messages.append({'role': roles[-1][0], 'content': query})
    return messages
836840

837841

swift/ui/llm_eval/eval.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
import gradio as gr
55

66
from swift.ui.base import BaseUI
7+
from swift.utils import get_logger
8+
9+
logger = get_logger()
710

811

912
class Eval(BaseUI):
@@ -110,10 +113,22 @@ class Eval(BaseUI):
110113

111114
@classmethod
112115
def do_build_ui(cls, base_tab: Type['BaseUI']):
116+
try:
117+
from llmuses.backend.opencompass import OpenCompassBackendManager
118+
except ImportError as e:
119+
logger.error('You are using web-ui, please '
120+
'install requirements by `pip install llmuses ms-opencompass -U`')
121+
raise e
122+
113123
with gr.Row():
114124
gr.Textbox(elem_id='name', scale=20)
115125
gr.Dropdown(
116-
elem_id='eval_dataset', is_list=True, choices=['ceval', 'gsm8k', 'arc'], multiselect=True, scale=20)
126+
elem_id='eval_dataset',
127+
is_list=True,
128+
choices=OpenCompassBackendManager.list_datasets(),
129+
multiselect=True,
130+
allow_custom_value=True,
131+
scale=20)
117132
gr.Textbox(elem_id='eval_few_shot', scale=20)
118133
gr.Textbox(elem_id='eval_limit', scale=20)
119134
gr.Checkbox(elem_id='eval_use_cache', scale=20)

swift/ui/llm_infer/llm_infer.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from gradio import Accordion, Tab
1414
from modelscope import GenerationConfig, snapshot_download
1515

16-
from swift.llm import (DeployArguments, InferArguments, XRequestConfig, inference_client, inference_stream,
17-
limit_history_length, prepare_model_template)
16+
from swift.llm import (TEMPLATE_MAPPING, DeployArguments, InferArguments, XRequestConfig, inference_client,
17+
inference_stream, limit_history_length, prepare_model_template)
1818
from swift.ui.base import BaseUI
1919
from swift.ui.llm_infer.model import Model
2020
from swift.ui.llm_infer.runtime import Runtime
@@ -349,7 +349,7 @@ def prepare_checkpoint(cls, *args):
349349

350350
@classmethod
351351
def clear_session(cls):
352-
return '', [], None, []
352+
return '', [], gr.update(value=None, interactive=True), []
353353

354354
@classmethod
355355
def change_interactive(cls):
@@ -365,6 +365,14 @@ def _replace_tag_with_media(cls, history):
365365
total_history.append(h[:2])
366366
return total_history
367367

368+
@classmethod
def agent_type(cls, response):
    """Best-effort detection of the agent prompt style from a model response.

    Args:
        response: The model's generated text so far.

    Returns:
        ``'react'`` when the response ends with an ``Observation:`` marker,
        ``'toolbench'`` when it contains ``Action Input:`` but no
        ``Observation:``, otherwise ``None``.
    """
    lowered = response.lower()
    # ReAct-style agents stop generation right after emitting 'Observation:'.
    if lowered.endswith('observation:'):
        return 'react'
    # ToolBench responses carry 'Action Input:' without any observation text.
    if 'observation:' not in lowered and 'action input:' in lowered:
        return 'toolbench'
    return None
375+
368376
@classmethod
369377
def send_message(cls, running_task, model_and_template, template_type, prompt: str, image, history, system,
370378
max_new_tokens, temperature, top_k, top_p, repetition_penalty):
@@ -393,20 +401,38 @@ def send_message(cls, running_task, model_and_template, template_type, prompt: s
393401
request_config.stop = ['Observation:']
394402
stream_resp_with_history = ''
395403
medias = [m for h in old_history for m in h[2]]
404+
media_infer_type = TEMPLATE_MAPPING[template].get('infer_media_type', 'round')
405+
image_interactive = media_infer_type != 'dialogue'
406+
407+
text_history = [h for h in old_history if h[0]]
408+
roles = []
409+
for i in range(len(text_history) + 1):
410+
roles.append(['user', 'assistant'])
411+
412+
for i, h in enumerate(text_history):
413+
agent_type = cls.agent_type(h[1])
414+
if i < len(text_history) - 1 and agent_type == 'toolbench':
415+
roles[i + 1][0] = 'tool'
416+
if i == len(text_history) - 1 and agent_type in ('toolbench', 'react'):
417+
roles[i + 1][0] = 'tool'
418+
396419
if not template_type.endswith('generation'):
397420
stream_resp = inference_client(
398421
model_type,
399422
prompt,
400423
images=medias,
401-
history=[h[:2] for h in old_history if h[0]],
424+
history=[h[:2] for h in text_history],
402425
system=system,
403426
port=args['port'],
404-
request_config=request_config)
427+
request_config=request_config,
428+
roles=roles,
429+
)
405430
for chunk in stream_resp:
406431
stream_resp_with_history += chunk.choices[0].delta.content
407432
old_history[-1][0] = prompt
408433
old_history[-1][1] = stream_resp_with_history
409-
yield '', cls._replace_tag_with_media(old_history), None, old_history
434+
yield ('', cls._replace_tag_with_media(old_history),
435+
gr.update(value=None, interactive=image_interactive), old_history)
410436
else:
411437
request_config.max_tokens = max_new_tokens
412438
stream_resp = inference_client(
@@ -415,7 +441,8 @@ def send_message(cls, running_task, model_and_template, template_type, prompt: s
415441
stream_resp_with_history += chunk.choices[0].text
416442
old_history[-1][0] = prompt
417443
old_history[-1][1] = stream_resp_with_history
418-
yield '', cls._replace_tag_with_media(old_history), None, old_history
444+
yield ('', cls._replace_tag_with_media(old_history),
445+
gr.update(value=None, interactive=image_interactive), old_history)
419446

420447
@classmethod
421448
def generate_chat(cls, model_and_template, template_type, prompt: str, image, history, system, max_new_tokens,

swift/ui/llm_train/advanced.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,86 @@ class Advanced(BaseUI):
7676
'en': 'Input in the json format'
7777
}
7878
},
79+
'custom_train_dataset_path': {
80+
'label': {
81+
'zh': '自定义训练数据集路径',
82+
'en': 'Custom train dataset path'
83+
},
84+
'info': {
85+
'zh': '输入自定义的训练数据集路径,空格分隔',
86+
'en': 'Extra train files, split by blank'
87+
}
88+
},
89+
'custom_val_dataset_path': {
90+
'label': {
91+
'zh': '自定义校验数据集路径',
92+
'en': 'Custom val dataset path'
93+
},
94+
'info': {
95+
'zh': '输入自定义的校验数据集路径,逗号分隔',
96+
'en': 'Extra val files, split by comma'
97+
}
98+
},
99+
'truncation_strategy': {
100+
'label': {
101+
'zh': '数据集超长策略',
102+
'en': 'Dataset truncation strategy'
103+
},
104+
'info': {
105+
'zh': '如果token超长该如何处理',
106+
'en': 'How to deal with the rows exceed the max length'
107+
}
108+
},
109+
'gpu_memory_fraction': {
110+
'label': {
111+
'zh': 'GPU显存限制',
112+
'en': 'GPU memory fraction'
113+
},
114+
'info': {
115+
'zh': '设置使用显存的比例,一般用于显存测试',
116+
'en': 'Set the memory fraction ratio of GPU, usually used in memory test'
117+
}
118+
},
119+
'max_steps': {
120+
'label': {
121+
'zh': '最大迭代步数',
122+
'en': 'Max steps',
123+
},
124+
'info': {
125+
'zh': '设置最大迭代步数,该值如果大于零则数据集迭代次数不生效',
126+
'en': 'Set the max steps, if the value > 0 then num_train_epochs has no effects',
127+
}
128+
},
129+
'eval_batch_size': {
130+
'label': {
131+
'zh': '验证batch size',
132+
'en': 'Val batch size',
133+
},
134+
'info': {
135+
'zh': '验证的batch size',
136+
'en': 'Set the val batch size',
137+
}
138+
},
139+
'max_grad_norm': {
140+
'label': {
141+
'zh': '梯度裁剪',
142+
'en': 'Max grad norm',
143+
},
144+
'info': {
145+
'zh': '设置梯度裁剪',
146+
'en': 'Set the max grad norm',
147+
}
148+
},
149+
'predict_with_generate': {
150+
'label': {
151+
'zh': '使用生成指标代替loss',
152+
'en': 'Use generate metric instead of loss',
153+
},
154+
'info': {
155+
'zh': '验证时使用generate/Rouge代替loss',
156+
'en': 'Use model.generate/Rouge instead of loss',
157+
}
158+
},
79159
}
80160

81161
@classmethod
@@ -87,6 +167,16 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
87167
gr.Textbox(elem_id='weight_decay', lines=1, scale=20)
88168
gr.Textbox(elem_id='logging_steps', lines=1, scale=20)
89169
gr.Textbox(elem_id='lr_scheduler_type', lines=1, scale=20)
170+
gr.Textbox(elem_id='max_steps', lines=1, scale=20)
90171
gr.Slider(elem_id='warmup_ratio', minimum=0.0, maximum=1.0, step=0.05, scale=20)
172+
with gr.Row():
173+
gr.Textbox(elem_id='custom_train_dataset_path', is_list=True, scale=20)
174+
gr.Textbox(elem_id='custom_val_dataset_path', is_list=True, scale=20)
175+
gr.Dropdown(elem_id='truncation_strategy', scale=20)
176+
gr.Slider(elem_id='eval_batch_size', minimum=1, maximum=256, step=2, scale=20)
177+
gr.Textbox(elem_id='max_grad_norm', lines=1, scale=20)
178+
gr.Checkbox(elem_id='predict_with_generate', scale=20)
179+
with gr.Row():
180+
gr.Textbox(elem_id='gpu_memory_fraction', scale=4)
91181
with gr.Row():
92182
gr.Textbox(elem_id='more_params', lines=4, scale=20)

swift/ui/llm_train/dataset.py

Lines changed: 20 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,6 @@ class Dataset(BaseUI):
3232
'en': 'Set the max length input to the model',
3333
}
3434
},
35-
'custom_train_dataset_path': {
36-
'label': {
37-
'zh': '自定义训练数据集路径',
38-
'en': 'Custom train dataset path'
39-
},
40-
'info': {
41-
'zh': '输入自定义的训练数据集路径,空格分隔',
42-
'en': 'Extra train files, split by blank'
43-
}
44-
},
45-
'custom_val_dataset_path': {
46-
'label': {
47-
'zh': '自定义校验数据集路径',
48-
'en': 'Custom val dataset path'
49-
},
50-
'info': {
51-
'zh': '输入自定义的校验数据集路径,逗号分隔',
52-
'en': 'Extra val files, split by comma'
53-
}
54-
},
5535
'dataset_test_ratio': {
5636
'label': {
5737
'zh': '验证集拆分比例',
@@ -82,16 +62,6 @@ class Dataset(BaseUI):
8262
'en': 'Validate with the sample size from the dataset',
8363
}
8464
},
85-
'truncation_strategy': {
86-
'label': {
87-
'zh': '数据集超长策略',
88-
'en': 'Dataset truncation strategy'
89-
},
90-
'info': {
91-
'zh': '如果token超长该如何处理',
92-
'en': 'How to deal with the rows exceed the max length'
93-
}
94-
},
9565
'custom_dataset_info': {
9666
'label': {
9767
'zh': '外部数据集配置',
@@ -102,18 +72,27 @@ class Dataset(BaseUI):
10272
'en': 'An extra dataset config to register your own datasets'
10373
}
10474
},
75+
'dataset_param': {
76+
'label': {
77+
'zh': '数据集设置',
78+
'en': 'Dataset settings'
79+
},
80+
},
10581
}
10682

10783
@classmethod
10884
def do_build_ui(cls, base_tab: Type['BaseUI']):
109-
with gr.Row():
110-
gr.Dropdown(elem_id='dataset', multiselect=True, choices=list(DATASET_MAPPING.keys()), scale=20)
111-
gr.Textbox(elem_id='custom_dataset_info', is_list=False, scale=20)
112-
gr.Textbox(elem_id='custom_train_dataset_path', is_list=True, scale=20)
113-
gr.Textbox(elem_id='custom_val_dataset_path', is_list=True, scale=20)
114-
with gr.Row():
115-
gr.Slider(elem_id='dataset_test_ratio', minimum=0.0, maximum=1.0, step=0.05, scale=20)
116-
gr.Slider(elem_id='max_length', minimum=32, maximum=32768, step=32, scale=20)
117-
gr.Textbox(elem_id='train_dataset_sample', scale=20)
118-
gr.Textbox(elem_id='val_dataset_sample', scale=20)
119-
gr.Dropdown(elem_id='truncation_strategy', scale=20)
85+
with gr.Accordion(elem_id='dataset_param', open=True):
86+
with gr.Row():
87+
gr.Dropdown(
88+
elem_id='dataset',
89+
multiselect=True,
90+
choices=list(DATASET_MAPPING.keys()),
91+
scale=20,
92+
allow_custom_value=True)
93+
gr.Textbox(elem_id='custom_dataset_info', is_list=False, scale=20)
94+
with gr.Row():
95+
gr.Slider(elem_id='dataset_test_ratio', minimum=0.0, maximum=1.0, step=0.05, scale=20)
96+
gr.Slider(elem_id='max_length', minimum=32, maximum=32768, step=1, scale=20)
97+
gr.Textbox(elem_id='train_dataset_sample', scale=20)
98+
gr.Textbox(elem_id='val_dataset_sample', scale=20)

0 commit comments

Comments
 (0)