Commit c5ad635

update code (#169)

1 parent 4dea862 commit c5ad635

16 files changed: +96 additions, -78 deletions

README.md

Lines changed: 3 additions & 3 deletions
@@ -66,7 +66,7 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用
 
 
 ## ✨ LLM SFT Example
-The detailed usage documentation for fine-tuning LLM can be found [here](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm).
+Users can refer to the [LLM fine-tuning documentation](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm) for more detailed information.
 
 ### Features
 - Supported SFT Methods: [lora](https://arxiv.org/abs/2106.09685), [qlora](https://arxiv.org/abs/2305.14314), full(full parameter fine-tuning)
@@ -180,7 +180,7 @@ model, tokenizer = get_model_tokenizer(model_type, torch.bfloat16, {'device_map'
 model = Swift.from_pretrained(model, model_dir, inference_mode=True)
 template = get_template(template_type, tokenizer)
 query = 'xxxxxx'
-response, history = inference(model, template, query, verbose=False)
+response, history = inference(model, template, query)
 print(f'response: {response}')
 print(f'history: {history}')
 ```
@@ -204,7 +204,7 @@ model, tokenizer = get_model_tokenizer(model_type, torch.bfloat16, {'device_map'
 
 template = get_template(template_type, tokenizer)
 query = 'xxxxxx'
-response, history = inference(model, template, query, verbose=False)
+response, history = inference(model, template, query)
 print(f'response: {response}')
 print(f'history: {history}')
 ```
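Read together with the swift/llm/infer.py change further down, this README edit reflects a shift in defaults: library callers simply get the response back, while the interactive CLI paths now ask for console output explicitly. Below is a minimal sketch of the two call styles, reusing only names that appear in the snippets in this commit (model/template setup is elided exactly as in the README example):

```python
# Sketch based solely on calls shown in this commit; model, template and query
# are prepared as in the README example above (get_model_tokenizer, get_template).

# Programmatic use after this commit: no verbose flag, the result is returned.
response, history = inference(model, template, query)
print(f'response: {response}')
print(f'history: {history}')

# Interactive-style use, mirroring swift/llm/infer.py below: printing/streaming
# are requested explicitly via verbose=True and stream=True.
_, history = inference(model, template, query, stream=True, verbose=True)
```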

README_CN.md

Lines changed: 3 additions & 3 deletions
@@ -64,7 +64,7 @@ SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning) is an extensible
 
 
 ## ✨ LLM Fine-Tuning Examples
-The detailed usage documentation for LLM fine-tuning can be found [here](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm).
+Users can refer to the [LLM fine-tuning documentation](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm) for a more detailed introduction.
 
 ### Features
 - Supported SFT methods: [lora](https://arxiv.org/abs/2106.09685), [qlora](https://arxiv.org/abs/2305.14314), full-parameter fine-tuning
@@ -177,7 +177,7 @@ model, tokenizer = get_model_tokenizer(model_type, torch.bfloat16, {'device_map'
 model = Swift.from_pretrained(model, model_dir, inference_mode=True)
 template = get_template(template_type, tokenizer)
 query = 'xxxxxx'
-response, history = inference(model, template, query, verbose=False)
+response, history = inference(model, template, query)
 print(f'response: {response}')
 print(f'history: {history}')
 ```
@@ -201,7 +201,7 @@ model, tokenizer = get_model_tokenizer(model_type, torch.bfloat16, {'device_map'
 
 template = get_template(template_type, tokenizer)
 query = 'xxxxxx'
-response, history = inference(model, template, query, verbose=False)
+response, history = inference(model, template, query)
 print(f'response: {response}')
 print(f'history: {history}')
 ```

examples/pytorch/llm/README.md

Lines changed: 2 additions & 2 deletions
@@ -143,7 +143,7 @@ model, tokenizer = get_model_tokenizer(model_type, torch.bfloat16, {'device_map'
 model = Swift.from_pretrained(model, model_dir, inference_mode=True)
 template = get_template(template_type, tokenizer)
 query = 'xxxxxx'
-response, history = inference(model, template, query, verbose=False)
+response, history = inference(model, template, query)
 print(f'response: {response}')
 print(f'history: {history}')
 ```
@@ -167,7 +167,7 @@ model, tokenizer = get_model_tokenizer(model_type, torch.bfloat16, {'device_map'
 
 template = get_template(template_type, tokenizer)
 query = 'xxxxxx'
-response, history = inference(model, template, query, verbose=False)
+response, history = inference(model, template, query)
 print(f'response: {response}')
 print(f'history: {history}')
 ```

examples/pytorch/llm/README_CN.md

Lines changed: 2 additions & 2 deletions
@@ -142,7 +142,7 @@ model, tokenizer = get_model_tokenizer(model_type, torch.bfloat16, {'device_map'
 model = Swift.from_pretrained(model, model_dir, inference_mode=True)
 template = get_template(template_type, tokenizer)
 query = 'xxxxxx'
-response, history = inference(model, template, query, verbose=False)
+response, history = inference(model, template, query)
 print(f'response: {response}')
 print(f'history: {history}')
 ```
@@ -166,7 +166,7 @@ model, tokenizer = get_model_tokenizer(model_type, torch.bfloat16, {'device_map'
 
 template = get_template(template_type, tokenizer)
 query = 'xxxxxx'
-response, history = inference(model, template, query, verbose=False)
+response, history = inference(model, template, query)
 print(f'response: {response}')
 print(f'history: {history}')
 ```

examples/pytorch/llm/scripts/tongyi_finance_14b_chat_int4/qlora/sft.sh

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ python llm_sft.py \
     --dtype fp16 \
     --output_dir output \
     --custom_train_dataset_path xxx.jsonl \
+    --custom_val_dataset_path yyy.jsonl \
     --train_dataset_sample -1 \
     --num_train_epochs 1 \
     --max_length 4096 \

swift/cli/merge_lora.py

Lines changed: 2 additions & 4 deletions
@@ -1,6 +1,4 @@
-from swift.llm import InferArguments, merge_lora
-from swift.utils import parse_args
+from swift.llm.run import merge_lora_main
 
 if __name__ == '__main__':
-    args, remaining_argv = parse_args(InferArguments)
-    merge_lora(args, replace_if_exists=True)
+    merge_lora_main(replace_if_exists=True)
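The CLI script no longer parses InferArguments itself; it delegates to the shared get_main wiring added in swift/llm/run.py below. A hedged sketch of the new flow, assuming (as the call above implies) that the wrapper returned by get_main forwards extra keyword arguments to merge_lora:

```python
# Sketch only: merge_lora_main is get_main(InferArguments, merge_lora) from
# swift/llm/run.py in this commit. It is assumed to parse InferArguments from
# the command line (the checkpoint must be a LoRA checkpoint; see the asserts
# in swift/llm/infer.py) and to forward extra kwargs to merge_lora.
from swift.llm.run import merge_lora_main

if __name__ == '__main__':
    # replace_if_exists=True overwrites an already-merged weight directory
    # instead of skipping the save (see the infer.py hunk below).
    merge_lora_main(replace_if_exists=True)
```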

swift/llm/infer.py

Lines changed: 9 additions & 5 deletions
@@ -18,7 +18,7 @@
 logger = get_logger()
 
 
-def merge_lora(args: InferArguments, replace_if_exists=False) -> None:
+def merge_lora(args: InferArguments, replace_if_exists=False) -> str:
     logger.info(f'replace_if_exists: {replace_if_exists}')
     assert args.ckpt_dir is not None
     assert args.sft_type == 'lora'
@@ -66,7 +66,7 @@ def merge_lora(args: InferArguments, replace_if_exists=False) -> None:
     res.pop('adapter_cfg', None)
     with open(new_configuration_path, 'w') as f:
         json.dump(res, f, ensure_ascii=False, indent=4)
-    # sft_args
+    # sft_args.json
     sft_args_fname = 'sft_args.json'
     old_sft_args_path = os.path.join(old_ckpt_dir, sft_args_fname)
     new_sft_args_path = os.path.join(args.ckpt_dir, sft_args_fname)
@@ -80,7 +80,9 @@ def merge_lora(args: InferArguments, replace_if_exists=False) -> None:
     else:
         logger.info(
             f'The weight directory for the merged LoRA already exists in {args.ckpt_dir}, '
-            'skipping the saving process.')
+            'skipping the saving process. '
+            'you can pass `replace_if_exists=True` to overwrite it.')
+    return merged_lora_path
 
 
 def prepare_model_template(
@@ -152,7 +154,8 @@ def llm_infer(args: InferArguments) -> None:
     if args.eval_human:
         while True:
             query = input('<<< ')
-            _, history = inference(model, template, query, stream=args.stream)
+            _, history = inference(
+                model, template, query, stream=args.stream, verbose=True)
             item = history[0]
             if jsonl_path is not None:
                 save_result_to_jsonl(jsonl_path, item[0], item[1])
@@ -175,7 +178,8 @@ def llm_infer(args: InferArguments) -> None:
                 data.get('query'),
                 data.get('history'),
                 data.get('system'),
-                stream=args.stream)
+                stream=args.stream,
+                verbose=True)
             label = data.get('response')
             item = history[0]
             if jsonl_path is not None:

swift/llm/rome.py

Lines changed: 3 additions & 2 deletions
@@ -75,7 +75,7 @@ def rome_infer(args: RomeArguments) -> None:
     if args.eval_human:
         while True:
             query = input('<<< ')
-            inference(model, template, query, stream=args.stream)
+            inference(model, template, query, stream=args.stream, verbose=True)
     else:
         _, val_dataset = get_dataset(args.dataset, args.dataset_test_ratio,
                                      args.dataset_seed)
@@ -88,7 +88,8 @@ def rome_infer(args: RomeArguments) -> None:
                 data.get('query'),
                 data.get('history'),
                 data.get('system'),
-                stream=args.stream)
+                stream=args.stream,
+                verbose=True)
             print()
             print(f"[LABELS]{data.get('response')}")
             print('-' * 80)

swift/llm/run.py

Lines changed: 2 additions & 1 deletion
@@ -1,8 +1,9 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from swift.llm import (InferArguments, RomeArguments, SftArguments, get_main,
-                       llm_infer, llm_sft, llm_web_ui, rome_infer)
+                       llm_infer, llm_sft, llm_web_ui, merge_lora, rome_infer)
 
 sft_main = get_main(SftArguments, llm_sft)
 infer_main = get_main(InferArguments, llm_infer)
 rome_main = get_main(RomeArguments, rome_infer)
 web_ui_main = get_main(InferArguments, llm_web_ui)
+merge_lora_main = get_main(InferArguments, merge_lora)

swift/llm/sft.py

Lines changed: 2 additions & 2 deletions
@@ -267,7 +267,7 @@ def llm_sft(args: SftArguments) -> str:
             f,
             ensure_ascii=False,
             indent=2)
-    res = trainer.train(training_args.resume_from_checkpoint)
+    trainer.train(training_args.resume_from_checkpoint)
     logger.info(
         f'best_model_checkpoint: {trainer.state.best_model_checkpoint}')
 
@@ -283,6 +283,6 @@ def llm_sft(args: SftArguments) -> str:
     return {
         'best_model_checkpoint': trainer.state.best_model_checkpoint,
         'best_metric': trainer.state.best_metric,
-        'global_step': res.global_step,
+        'global_step': trainer.state.global_step,
         'log_history': trainer.state.log_history,
     }
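llm_sft no longer keeps the object returned by trainer.train(); everything it reports is read from trainer.state afterwards, so the summary dict it returns keeps the same keys. A hedged sketch of consuming that summary, assuming args is an SftArguments instance configured elsewhere:

```python
# Sketch only: args is an SftArguments instance; llm_sft returns the summary
# dict shown in the hunk above.
from swift.llm import llm_sft

result = llm_sft(args)
print(result['best_model_checkpoint'])
print(result['best_metric'])
print(result['global_step'])   # now read from trainer.state.global_step
for entry in result['log_history']:
    print(entry)
```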
