Skip to content

Commit 4511944

Browse files
Support studio (#300)
1 parent 178033d commit 4511944

File tree

12 files changed

+236
-21
lines changed

12 files changed

+236
-21
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Experimental environment: A100
# Runs inference from a fine-tuned checkpoint on GPU 0.
# NOTE: "vx_xxx/checkpoint-xxx" looks like a placeholder; point --ckpt_dir at
# the real checkpoint directory before running.
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python llm_infer.py \
    --ckpt_dir "output/llama2-13b-chat/vx_xxx/checkpoint-xxx" \
    --load_dataset_config true \
    --max_length 4096 \
    --max_new_tokens 2048 \
    --temperature 0.1 \
    --top_p 0.7 \
    --repetition_penalty 1.05 \
    --do_sample true \
    --merge_lora_and_save false
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Experimental environment: 2 * A100
# 2 * 37GB GPU memory
# Launches llm_sft.py through torchrun with one process per visible GPU.
nproc_per_node=2

# The gradient accumulation steps are divided by the process count, so
# batch_size * nproc_per_node * gradient_accumulation_steps stays at 16.
# NOTE: 'your-sdk-token' is a placeholder; it only matters when
# --push_to_hub is switched to true.
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0,1 \
torchrun \
    --nproc_per_node=$nproc_per_node \
    --master_port 29500 \
    llm_sft.py \
    --model_id_or_path modelscope/Llama-2-13b-chat-ms \
    --model_revision master \
    --sft_type longlora \
    --tuner_backend swift \
    --template_type llama \
    --dtype AUTO \
    --output_dir output \
    --ddp_backend nccl \
    --dataset leetcode-python-en \
    --train_dataset_sample -1 \
    --num_train_epochs 1 \
    --max_length 4096 \
    --check_dataset_strategy warning \
    --lora_rank 8 \
    --lora_alpha 32 \
    --lora_dropout_p 0.05 \
    --lora_target_modules ALL \
    --gradient_checkpointing true \
    --batch_size 1 \
    --weight_decay 0.01 \
    --learning_rate 1e-4 \
    --gradient_accumulation_steps $((16 / nproc_per_node)) \
    --max_grad_norm 0.5 \
    --warmup_ratio 0.03 \
    --eval_steps 100 \
    --save_steps 100 \
    --save_total_limit 2 \
    --logging_steps 10 \
    --push_to_hub false \
    --hub_model_id llama2-13b-chat-longlora \
    --hub_private_repo true \
    --hub_token 'your-sdk-token' \
    --deepspeed_config_path 'ds_config/zero2.json' \
    --save_only_model true
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Experimental environment: V100, A10, 3090
# Runs inference from a fine-tuned checkpoint on GPU 0.
# NOTE: "vx_xxx/checkpoint-xxx" looks like a placeholder; point --ckpt_dir at
# the real checkpoint directory before running.
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python llm_infer.py \
    --ckpt_dir "output/qwen-7b-chat-int4/vx_xxx/checkpoint-xxx" \
    --load_dataset_config true \
    --max_length 4096 \
    --use_flash_attn false \
    --max_new_tokens 2048 \
    --temperature 0.1 \
    --top_p 0.7 \
    --repetition_penalty 1.05 \
    --do_sample true \
    --merge_lora_and_save false
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Experimental environment: V100, A10, 3090
# 14GB GPU memory
# Single-GPU fine-tuning run of llm_sft.py.
# NOTE: 'your-sdk-token' is a placeholder; it only matters when
# --push_to_hub is switched to true.
PYTHONPATH=../../.. \
CUDA_VISIBLE_DEVICES=0 \
python llm_sft.py \
    --model_id_or_path qwen/Qwen-7B-Chat-Int4 \
    --model_revision master \
    --sft_type qalora \
    --tuner_backend swift \
    --template_type qwen \
    --dtype fp16 \
    --output_dir output \
    --dataset leetcode-python-en \
    --train_dataset_sample -1 \
    --num_train_epochs 1 \
    --max_length 4096 \
    --check_dataset_strategy warning \
    --lora_rank 8 \
    --lora_alpha 32 \
    --lora_dropout_p 0.05 \
    --lora_target_modules ALL \
    --gradient_checkpointing true \
    --batch_size 1 \
    --weight_decay 0.01 \
    --learning_rate 1e-4 \
    --gradient_accumulation_steps 16 \
    --max_grad_norm 0.5 \
    --warmup_ratio 0.03 \
    --eval_steps 100 \
    --save_steps 100 \
    --save_total_limit 2 \
    --logging_steps 10 \
    --use_flash_attn false \
    --push_to_hub false \
    --hub_model_id qwen-7b-chat-int4-qalora \
    --hub_private_repo true \
    --hub_token 'your-sdk-token'

swift/tuners/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ def set_active_adapters(self,
461461
adapter_names: Union[List[str], str],
462462
offload=None):
463463
if not adapter_names:
464-
return
464+
adapter_names = []
465465

466466
if isinstance(adapter_names, str):
467467
adapter_names = [adapter_names]

swift/ui/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,5 @@ def run_ui():
3434
LLMTrain.build_ui(LLMTrain)
3535
LLMInfer.build_ui(LLMInfer)
3636

37-
app.queue().launch(height=800, share=False)
37+
# Only create a public share link when WEBUI_SHARE is explicitly truthy.
# bool() of any non-empty string -- including the default '0' -- is True,
# so a plain bool() conversion would always enable sharing.
app.queue().launch(
    height=800,
    share=os.environ.get('WEBUI_SHARE', '0').strip().lower()
    in ('1', 'true', 'yes', 'on'))

swift/ui/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from gradio import (Accordion, Button, Checkbox, Dropdown, Slider, Tab,
88
TabItem, Textbox)
99

10+
from swift.llm.utils.model import MODEL_MAPPING, ModelType
11+
1012
all_langs = ['zh', 'en']
1113
builder: Type['BaseUI'] = None
1214
base_builder: Type['BaseUI'] = None
@@ -168,3 +170,8 @@ def get_default_value_from_dataclass(dataclass):
168170
else:
169171
default_dict[f.name] = None
170172
return default_dict
173+
174+
@staticmethod
def get_custom_name_list():
    """Return the names in ``MODEL_MAPPING`` that ``ModelType`` does not list.

    The result is every key of ``MODEL_MAPPING`` that is absent from
    ``ModelType.get_model_name_list()``.

    Returns:
        list: the extra model names, in ``MODEL_MAPPING`` iteration order.
    """
    builtin_names = set(ModelType.get_model_name_list())
    # Filter the mapping's keys instead of subtracting sets: a set difference
    # has arbitrary ordering, which would shuffle the dropdown choices
    # between runs.
    return [
        name for name in MODEL_MAPPING.keys() if name not in builtin_names
    ]

swift/ui/llm_infer/llm_infer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
import re
3-
from dataclasses import fields
43
from typing import Type
54

65
import gradio as gr
@@ -138,7 +137,6 @@ def reset_memory(cls):
138137

139138
@classmethod
140139
def prepare_checkpoint(cls, *args):
141-
global model, tokenizer, template
142140
torch.cuda.empty_cache()
143141
infer_args = cls.get_default_value_from_dataclass(InferArguments)
144142
kwargs = {}
@@ -201,6 +199,8 @@ def generate_chat(cls, model_and_template, template_type, prompt: str,
201199
gr.Warning(cls.locale('generate_alert', cls.lang)['value'])
202200
return '', None
203201
model, template = model_and_template
202+
if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
203+
model.cuda()
204204
if not template_type.endswith('generation'):
205205
old_history, history = limit_history_length(
206206
template, prompt, history, int(max_new_tokens))
@@ -211,3 +211,5 @@ def generate_chat(cls, model_and_template, template_type, prompt: str,
211211
for _, history in gen:
212212
total_history = old_history + history
213213
yield '', total_history
214+
if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
215+
model.cpu()

swift/ui/llm_infer/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
8686
model_type = gr.Dropdown(
8787
elem_id='model_type',
8888
choices=[base_tab.locale('checkpoint', cls.lang)['value']]
89-
+ ModelType.get_model_name_list(),
89+
+ ModelType.get_model_name_list() + cls.get_custom_name_list(),
9090
value=base_tab.locale('checkpoint', cls.lang)['value'],
9191
scale=20)
9292
model_id_or_path = gr.Textbox(

swift/ui/llm_train/llm_train.py

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import collections
12
import os
23
import sys
34
import time
5+
from subprocess import PIPE, STDOUT, Popen
46
from typing import Dict, Type
57

68
import gradio as gr
@@ -191,16 +193,31 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
191193
Quantization.build_ui(base_tab)
192194
SelfCog.build_ui(base_tab)
193195
Advanced.build_ui(base_tab)
194-
submit.click(
195-
cls.train, [
196-
value for value in cls.elements().values()
197-
if not isinstance(value, (Tab, Accordion))
198-
], [
199-
cls.element('running_cmd'),
200-
cls.element('logging_dir'),
201-
cls.element('runtime_tab')
202-
],
203-
show_progress=True)
196+
if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
197+
submit.click(
198+
cls.update_runtime, [],
199+
[cls.element('runtime_tab'),
200+
cls.element('log')]).then(
201+
cls.train_studio, [
202+
value for value in cls.elements().values()
203+
if not isinstance(value, (Tab, Accordion))
204+
], [cls.element('log')],
205+
queue=True)
206+
else:
207+
submit.click(
208+
cls.train_local, [
209+
value for value in cls.elements().values()
210+
if not isinstance(value, (Tab, Accordion))
211+
], [
212+
cls.element('running_cmd'),
213+
cls.element('logging_dir'),
214+
cls.element('runtime_tab'),
215+
],
216+
queue=True)
217+
218+
@classmethod
def update_runtime(cls):
    """Return two Gradio updates, each setting ``visible=True``.

    Used as the first step of the studio submit handler, whose outputs are
    the ``runtime_tab`` and ``log`` elements.
    """
    show_runtime_tab = gr.update(visible=True)
    show_log = gr.update(visible=True)
    return show_runtime_tab, show_log
204221

205222
@classmethod
206223
def train(cls, *args):
@@ -239,7 +256,8 @@ def train(cls, *args):
239256
params += f'--{e} {kwargs[e]} '
240257
else:
241258
params += f'--{e} "{kwargs[e]}" '
242-
params += '--add_output_dir_suffix False '
259+
# Keep the trailing space: further ``--key "value"`` pairs are appended
# directly after this string, and without it the next flag would be glued
# onto the logging_dir value. Paths are quoted like the other string
# parameters so directories containing spaces survive shell parsing.
params += (f'--add_output_dir_suffix False '
           f'--output_dir "{sft_args.output_dir}" '
           f'--logging_dir "{sft_args.logging_dir}" ')
243261
for key, param in more_params.items():
244262
params += f'--{key} "{param}" '
245263
ddp_param = ''
@@ -260,9 +278,30 @@ def train(cls, *args):
260278
if ddp_param:
261279
ddp_param = f'set {ddp_param} && '
262280
run_command = f'{cuda_param}{ddp_param}start /b swift sft {params} > {log_file} 2>&1'
281+
elif os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
282+
run_command = f'{cuda_param} {ddp_param} swift sft {params}'
263283
else:
264284
run_command = f'{cuda_param} {ddp_param} nohup swift sft {params} > {log_file} 2>&1 &'
265285
logger.info(f'Run training: {run_command}')
286+
return run_command, sft_args, other_kwargs
287+
288+
@classmethod
def train_studio(cls, *args):
    """Run the training command and stream its output (studio mode only).

    The command line is built by ``cls.train``. When the
    ``MODELSCOPE_ENVIRONMENT`` environment variable equals ``studio`` the
    command is started as a shell subprocess with stderr merged into
    stdout; otherwise nothing is run and nothing is yielded.

    Args:
        *args: forwarded unchanged to ``cls.train``.

    Yields:
        str: the most recent output lines joined by newlines. At most
        ``MAX_LOG_LINES`` lines (environment variable, default 50) are
        kept.
    """
    run_command, _, _ = cls.train(*args)
    if os.environ.get('MODELSCOPE_ENVIRONMENT') != 'studio':
        return
    lines = collections.deque(
        maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
    # run_command is a single command-line string assembled by cls.train,
    # so it has to be executed through the shell.
    process = Popen(run_command, shell=True, stdout=PIPE, stderr=STDOUT)
    with process.stdout:
        for raw_line in iter(process.stdout.readline, b''):
            # readline() keeps the line terminator; strip it so joining
            # with '\n' does not produce blank lines between entries.
            # Undecodable bytes are replaced rather than aborting the log.
            text = raw_line.decode('utf-8', errors='replace')
            lines.append(text.rstrip('\r\n'))
            yield '\n'.join(lines)
    # stdout reached EOF: reap the child so it does not linger as a zombie.
    process.wait()
301+
302+
@classmethod
303+
def train_local(cls, *args):
304+
run_command, sft_args, other_kwargs = cls.train(*args)
266305
if not other_kwargs['dry_run']:
267306
os.makedirs(sft_args.logging_dir, exist_ok=True)
268307
os.system(run_command)

0 commit comments

Comments
 (0)