Skip to content

Commit 3d6496c

Browse files
1. support deepspeed on ui 2. add tools to client_utils (#1446)
1 parent 2156e4c commit 3d6496c

File tree

8 files changed

+68
-14
lines changed

8 files changed

+68
-14
lines changed

swift/llm/utils/client_utils.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,8 @@ def _pre_inference_client(model_type: str,
168168
history: Optional[History] = None,
169169
system: Optional[str] = None,
170170
images: Optional[List[str]] = None,
171+
tools: Optional[List[Dict[str, Union[str, Dict]]]] = None,
172+
tool_choice: Optional[Union[str, Dict]] = 'auto',
171173
*,
172174
is_chat_request: Optional[bool] = None,
173175
request_config: Optional[XRequestConfig] = None,
@@ -212,7 +214,10 @@ def _pre_inference_client(model_type: str,
212214
data['model'] = model_type
213215
if len(images) > 0:
214216
data['images'] = images
215-
217+
if tools and len(tools) > 0:
218+
data['tools'] = tools
219+
if tool_choice:
220+
data['tool_choice'] = tool_choice
216221
return url, data, is_chat_request
217222

218223

@@ -222,6 +227,8 @@ def inference_client(
222227
history: Optional[History] = None,
223228
system: Optional[str] = None,
224229
images: Optional[List[str]] = None,
230+
tools: Optional[List[Dict[str, Union[str, Dict]]]] = None,
231+
tool_choice: Optional[Union[str, Dict]] = 'auto',
225232
*,
226233
is_chat_request: Optional[bool] = None,
227234
request_config: Optional[XRequestConfig] = None,
@@ -238,6 +245,8 @@ def inference_client(
238245
history,
239246
system,
240247
images,
248+
tools,
249+
tool_choice,
241250
is_chat_request=is_chat_request,
242251
request_config=request_config,
243252
host=host,
@@ -280,6 +289,8 @@ async def inference_client_async(
280289
history: Optional[History] = None,
281290
system: Optional[str] = None,
282291
images: Optional[List[str]] = None,
292+
tools: Optional[List[Dict[str, Union[str, Dict]]]] = None,
293+
tool_choice: Optional[Union[str, Dict]] = 'auto',
283294
*,
284295
is_chat_request: Optional[bool] = None,
285296
request_config: Optional[XRequestConfig] = None,
@@ -296,6 +307,8 @@ async def inference_client_async(
296307
history,
297308
system,
298309
images,
310+
tools,
311+
tool_choice,
299312
is_chat_request=is_chat_request,
300313
request_config=request_config,
301314
host=host,

swift/trainers/optimizers/galore/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def get_optimizer(args: TrainingArguments, config: GaLoreConfig) -> Tuple[Any, A
187187
optimizer_kwargs.update({'scale_parameter': False, 'relative_step': False})
188188
elif args.optim in ('adamw_hf', 'adamw_torch'):
189189
if config.quantize:
190-
assert importlib.util.find_spec("q_galore_torch") is not None, \
190+
assert importlib.util.find_spec('q_galore_torch') is not None, \
191191
'Please install q-galore by `pip install q_galore_torch`'
192192
from swift.utils import get_dist_setting
193193
_, _, world_size, _ = get_dist_setting()

swift/ui/llm_eval/llm_eval.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import json
1111
import torch
1212
from gradio import Accordion, Tab
13+
from json import JSONDecodeError
1314
from modelscope import snapshot_download
1415

1516
from swift.llm import EvalArguments
@@ -39,8 +40,8 @@ class LLMEval(BaseUI):
3940
'en': 'More params'
4041
},
4142
'info': {
42-
'zh': '以json格式填入',
43-
'en': 'Fill in with json format'
43+
'zh': '以json格式或--xxx xxx命令行格式填入',
44+
'en': 'Fill in with json format or --xxx xxx cmd format'
4445
}
4546
},
4647
'evaluate': {
@@ -113,6 +114,7 @@ def eval(cls, *args):
113114
kwargs_is_list = {}
114115
other_kwargs = {}
115116
more_params = {}
117+
more_params_cmd = ''
116118
keys = [key for key, value in cls.elements().items() if not isinstance(value, (Tab, Accordion))]
117119
for key, value in zip(keys, args):
118120
compare_value = eval_args.get(key)
@@ -130,7 +132,10 @@ def eval(cls, *args):
130132
else:
131133
other_kwargs[key] = value
132134
if key == 'more_params' and value:
133-
more_params = json.loads(value)
135+
try:
136+
more_params = json.loads(value)
137+
except (JSONDecodeError or TypeError):
138+
more_params_cmd = value
134139

135140
kwargs.update(more_params)
136141
if kwargs['model_type'] == cls.locale('checkpoint', cls.lang)['value']:
@@ -152,6 +157,7 @@ def eval(cls, *args):
152157
params += f'--{e} {kwargs[e]} '
153158
else:
154159
params += f'--{e} "{kwargs[e]}" '
160+
params += more_params_cmd + ' '
155161
devices = other_kwargs['gpu_id']
156162
devices = [d for d in devices if d]
157163
assert (len(devices) == 1 or 'cpu' not in devices)

swift/ui/llm_export/llm_export.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import json
1111
import torch
1212
from gradio import Accordion, Tab
13+
from json import JSONDecodeError
1314
from modelscope import snapshot_download
1415

1516
from swift.llm import ExportArguments
@@ -37,8 +38,8 @@ class LLMExport(BaseUI):
3738
'en': 'More params'
3839
},
3940
'info': {
40-
'zh': '以json格式填入',
41-
'en': 'Fill in with json format'
41+
'zh': '以json格式或--xxx xxx命令行格式填入',
42+
'en': 'Fill in with json format or --xxx xxx cmd format'
4243
}
4344
},
4445
'export': {
@@ -111,6 +112,7 @@ def export(cls, *args):
111112
kwargs_is_list = {}
112113
other_kwargs = {}
113114
more_params = {}
115+
more_params_cmd = ''
114116
keys = [key for key, value in cls.elements().items() if not isinstance(value, (Tab, Accordion))]
115117
for key, value in zip(keys, args):
116118
compare_value = export_args.get(key)
@@ -128,7 +130,10 @@ def export(cls, *args):
128130
else:
129131
other_kwargs[key] = value
130132
if key == 'more_params' and value:
131-
more_params = json.loads(value)
133+
try:
134+
more_params = json.loads(value)
135+
except (JSONDecodeError or TypeError):
136+
more_params_cmd = value
132137

133138
kwargs.update(more_params)
134139
if kwargs['model_type'] == cls.locale('checkpoint', cls.lang)['value']:
@@ -151,6 +156,7 @@ def export(cls, *args):
151156
params += f'--{e} {kwargs[e]} '
152157
else:
153158
params += f'--{e} "{kwargs[e]}" '
159+
params += more_params_cmd + ' '
154160
devices = other_kwargs['gpu_id']
155161
devices = [d for d in devices if d]
156162
assert (len(devices) == 1 or 'cpu' not in devices)

swift/ui/llm_infer/llm_infer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import json
1111
import torch
1212
from gradio import Accordion, Tab
13+
from json import JSONDecodeError
1314
from modelscope import GenerationConfig, snapshot_download
1415

1516
from swift.llm import (TEMPLATE_MAPPING, DeployArguments, InferArguments, XRequestConfig, inference_client,
@@ -215,6 +216,7 @@ def deploy(cls, *args):
215216
kwargs_is_list = {}
216217
other_kwargs = {}
217218
more_params = {}
219+
more_params_cmd = ''
218220
keys = [key for key, value in cls.elements().items() if not isinstance(value, (Tab, Accordion))]
219221
for key, value in zip(keys, args):
220222
compare_value = deploy_args.get(key)
@@ -232,7 +234,10 @@ def deploy(cls, *args):
232234
else:
233235
other_kwargs[key] = value
234236
if key == 'more_params' and value:
235-
more_params = json.loads(value)
237+
try:
238+
more_params = json.loads(value)
239+
except (JSONDecodeError or TypeError):
240+
more_params_cmd = value
236241

237242
kwargs.update(more_params)
238243
if kwargs['model_type'] == cls.locale('checkpoint', cls.lang)['value']:
@@ -263,6 +268,7 @@ def deploy(cls, *args):
263268
params += f'--{e} "{kwargs[e]}" '
264269
if 'port' not in kwargs:
265270
params += f'--port "{deploy_args.port}" '
271+
params += more_params_cmd + ' '
266272
devices = other_kwargs['gpu_id']
267273
devices = [d for d in devices if d]
268274
assert (len(devices) == 1 or 'cpu' not in devices)

swift/ui/llm_infer/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ class Model(BaseUI):
9696
'en': 'More params'
9797
},
9898
'info': {
99-
'zh': '以json格式填入',
100-
'en': 'Fill in with json format'
99+
'zh': '以json格式或--xxx xxx命令行格式填入',
100+
'en': 'Fill in with json format or --xxx xxx cmd format'
101101
}
102102
},
103103
'reset': {

swift/ui/llm_train/advanced.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ class Advanced(BaseUI):
7272
'en': 'Other params'
7373
},
7474
'info': {
75-
'zh': '以json格式输入其他超参数',
76-
'en': 'Input in the json format'
75+
'zh': '以json格式或--xxx xxx命令行格式填入',
76+
'en': 'Fill in with json format or --xxx xxx cmd format'
7777
}
7878
},
7979
'custom_train_dataset_path': {
@@ -156,6 +156,16 @@ class Advanced(BaseUI):
156156
'en': 'Use model.generate/Rouge instead of loss',
157157
}
158158
},
159+
'deepspeed': {
160+
'label': {
161+
'zh': 'deepspeed',
162+
'en': 'deepspeed',
163+
},
164+
'info': {
165+
'zh': '可以选择下拉列表,也支持传入路径',
166+
'en': 'Choose from the dropbox or fill in a valid path',
167+
}
168+
},
159169
}
160170

161171
@classmethod
@@ -177,6 +187,11 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
177187
gr.Textbox(elem_id='max_grad_norm', lines=1, scale=20)
178188
gr.Checkbox(elem_id='predict_with_generate', scale=20)
179189
with gr.Row():
190+
gr.Dropdown(
191+
elem_id='deepspeed',
192+
scale=4,
193+
allow_custom_value=True,
194+
choices=['default-zero2', 'default-zero3', 'zero3-offload'])
180195
gr.Textbox(elem_id='gpu_memory_fraction', scale=4)
181196
with gr.Row():
182197
gr.Textbox(elem_id='more_params', lines=4, scale=20)

swift/ui/llm_train/llm_train.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import json
1212
import torch
1313
from gradio import Accordion, Tab
14+
from json import JSONDecodeError
1415

1516
from swift.llm import RLHFArguments
1617
from swift.ui.base import BaseUI
@@ -294,6 +295,7 @@ def train(cls, *args):
294295
kwargs_is_list = {}
295296
other_kwargs = {}
296297
more_params = {}
298+
more_params_cmd = ''
297299
keys = [key for key, value in cls.elements().items() if not isinstance(value, (Tab, Accordion))]
298300
model_type = None
299301
do_rlhf = False
@@ -311,7 +313,10 @@ def train(cls, *args):
311313
else:
312314
other_kwargs[key] = value
313315
if key == 'more_params' and value:
314-
more_params = json.loads(value)
316+
try:
317+
more_params = json.loads(value)
318+
except (JSONDecodeError or TypeError):
319+
more_params_cmd = value
315320

316321
if key == 'model_type':
317322
model_type = value
@@ -327,6 +332,8 @@ def train(cls, *args):
327332
raise gr.Error(cls.locale('dataset_alert', cls.lang)['value'])
328333

329334
cmd = 'rlhf' if do_rlhf else 'sft'
335+
if kwargs.get('deepspeed'):
336+
more_params_cmd += f' --deepspeed {kwargs.pop("deepspeed")} '
330337
sft_args = RLHFArguments(
331338
**{
332339
key: value.split(' ') if kwargs_is_list.get(key, False) and isinstance(value, str) else value
@@ -341,6 +348,7 @@ def train(cls, *args):
341348
params += f'--{e} {kwargs[e]} '
342349
else:
343350
params += f'--{e} "{kwargs[e]}" '
351+
params += more_params_cmd + ' '
344352
params += f'--add_output_dir_suffix False --output_dir {sft_args.output_dir} ' \
345353
f'--logging_dir {sft_args.logging_dir} --ignore_args_error True'
346354
ddp_param = ''

0 commit comments

Comments
 (0)