
Commit 20aef70

add feat: only save model (#49)
1 parent b2064ea commit 20aef70

File tree

12 files changed (+106 −26 lines)


README.md

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ Key features:
 [code link](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm)
 
 1. supported SFT methods: [lora](https://arxiv.org/abs/2106.09685), [qlora](https://arxiv.org/abs/2305.14314), full(full parameter fine-tuning)
-2. supported models: qwen-7b, [qwen-7b-chat](https://github.com/QwenLM/Qwen-7B), qwen-vl, [qwen-vl-chat](https://github.com/QwenLM/Qwen-VL), baichuan-7b, baichuan-13b, baichuan-13b-chat, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-7b-chat, llama2-13b, llama2-13b-chat, llama2-70b, llama2-70b-chat, openbuddy-llama2-13b, openbuddy-llama-65b, polylm-13b
+2. supported models: qwen-7b, [qwen-7b-chat](https://github.com/QwenLM/Qwen-7B), qwen-vl, [qwen-vl-chat](https://github.com/QwenLM/Qwen-VL), baichuan-7b, baichuan-13b, baichuan-13b-chat, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-7b-chat, llama2-13b, llama2-13b-chat, llama2-70b, llama2-70b-chat, openbuddy-llama2-13b, openbuddy-llama-65b, openbuddy-llama2-70b, polylm-13b
 3. supported features: quantization, ddp, model parallelism(device map), gradient checkpointing, gradient accumulation, pushing to modelscope hub, custom datasets, multimodal and agent SFT, mutli-round chat, ...
 4. supported datasets:
    1. NLP: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, cot-en, cot-zh, firefly-all-zh, poetry-zh, instruct-en, gpt4all-en

README_CN.md

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ SWIFT(Scalable lightWeight Infrastructure for Fine-Tuning)是一个可扩展
 [code link](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm)
 
 1. 支持的SFT方法: [lora](https://arxiv.org/abs/2106.09685), [qlora](https://arxiv.org/abs/2305.14314), 全参数微调
-2. 支持的模型: qwen-7b, [qwen-7b-chat](https://github.com/QwenLM/Qwen-7B), qwen-vl, [qwen-vl-chat](https://github.com/QwenLM/Qwen-VL), baichuan-7b, baichuan-13b, baichuan-13b-chat, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-7b-chat, llama2-13b, llama2-13b-chat, llama2-70b, llama2-70b-chat, openbuddy-llama2-13b, openbuddy-llama-65b, polylm-13b
+2. 支持的模型: qwen-7b, [qwen-7b-chat](https://github.com/QwenLM/Qwen-7B), qwen-vl, [qwen-vl-chat](https://github.com/QwenLM/Qwen-VL), baichuan-7b, baichuan-13b, baichuan-13b-chat, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-7b-chat, llama2-13b, llama2-13b-chat, llama2-70b, llama2-70b-chat, openbuddy-llama2-13b, openbuddy-llama-65b, openbuddy-llama2-70b, polylm-13b
 3. 支持的特性: 模型量化, DDP, 模型并行(device_map), gradient checkpointing, 梯度累加, 支持推送ModelScope Hub, 自定义数据集, 多模态和Agent SFT, 多轮对话, ...
 4. 支持的数据集:
    1. NLP: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, cot-en, cot-zh, firefly-all-zh, poetry-zh, instruct-en, gpt4all-en

examples/pytorch/llm/scripts/qwen_7b_chat/full/sft.sh

Lines changed: 2 additions & 1 deletion
@@ -18,8 +18,9 @@ python src/llm_sft.py \
     --gradient_accumulation_steps 16 \
     --max_grad_norm 1 \
     --warmup_ratio 0.03 \
-    --eval_steps 50 \
+    --eval_steps 100 \
     --save_steps 100 \
+    --only_save_model true \
     --save_total_limit 2 \
     --logging_steps 10 \
     --use_flash_attn false \

examples/pytorch/llm/src/llm_infer.py

Lines changed: 2 additions & 2 deletions
@@ -109,8 +109,7 @@ def llm_infer(args: InferArguments) -> None:
         args.system,
         args.max_length,
         batched=False)
-    streamer = TextStreamer(
-        tokenizer, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextStreamer(tokenizer, skip_prompt=True)
     generation_config = GenerationConfig(
         max_new_tokens=args.max_new_tokens,
         temperature=args.temperature,
@@ -126,6 +125,7 @@ def llm_infer(args: InferArguments) -> None:
             query = input('<<< ')
             data = {'query': query}
             input_ids = preprocess_func(data)['input_ids']
+            streamer.decode_kwargs['skip_special_tokens'] = True
             inference(input_ids, model, tokenizer, streamer, generation_config,
                       args.skip_prompt)
     else:
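For context: transformers' TextStreamer keeps any extra keyword arguments in its decode_kwargs dict and forwards them to tokenizer.decode, which is what lets skip_special_tokens be set per query instead of in the constructor. A minimal sketch of the pattern (the gpt2 tokenizer is only a stand-in, not something this repo uses):

```python
from transformers import AutoTokenizer, TextStreamer

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # any tokenizer works for the sketch
streamer = TextStreamer(tokenizer, skip_prompt=True)

# Decode options can be toggled later, e.g. right before a generation call:
streamer.decode_kwargs['skip_special_tokens'] = True
# model.generate(input_ids, streamer=streamer, ...) would then stream text with
# special tokens stripped.
```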

examples/pytorch/llm/src/llm_sft.py

Lines changed: 10 additions & 7 deletions
@@ -85,6 +85,7 @@ class SftArguments:
 
     eval_steps: int = 50
     save_steps: Optional[int] = None
+    only_save_model: Optional[bool] = None
     save_total_limit: int = 2
     logging_steps: int = 5
     dataloader_num_workers: int = 1
@@ -126,23 +127,24 @@ def __post_init__(self):
         if self.sft_type == 'lora':
             if self.learning_rate is None:
                 self.learning_rate = 1e-4
-            if self.save_steps is None:
-                self.save_steps = self.eval_steps
+            if self.only_save_model is None:
+                self.only_save_model = False
         elif self.sft_type == 'full':
             assert self.quantization_bit is None, 'not supported'
             assert self.dtype != 'fp16', 'please use bf16 or fp32'
             if self.learning_rate is None:
                 self.learning_rate = 1e-5
-            if self.save_steps is None:
-                # Saving the model takes a long time
-                self.save_steps = self.eval_steps * 4
+            if self.only_save_model is None:
+                self.only_save_model = True
         else:
             raise ValueError(f'sft_type: {self.sft_type}')
+
         if self.template_type is None:
             self.template_type = MODEL_MAPPING[self.model_type].get(
                 'template', 'default')
             logger.info(f'Setting template_type: {self.template_type}')
-
+        if self.save_steps is None:
+            self.save_steps = self.eval_steps
         self.output_dir = os.path.join(self.output_dir, self.model_type)
 
         if self.lora_target_modules is None:
@@ -288,7 +290,8 @@ def llm_sft(args: SftArguments) -> None:
         resume_from_checkpoint=args.resume_from_ckpt,
         ddp_backend=args.ddp_backend,
         gradient_checkpointing=args.gradient_checkpointing,
-        local_rank=local_rank)
+        local_rank=local_rank,
+        only_save_model=args.only_save_model)
 
     if args.gradient_checkpointing:
         # fix: gradients will be None
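Taken together, the defaulting logic now reads roughly as sketched below: only_save_model defaults to False for LoRA and True for full-parameter fine-tuning, while save_steps simply falls back to eval_steps in both cases. The Args class is a condensed, hypothetical stand-in for SftArguments, not the repo's class:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    # Stand-ins for the relevant SftArguments fields.
    sft_type: str = 'lora'
    eval_steps: int = 50
    save_steps: Optional[int] = None
    only_save_model: Optional[bool] = None

    def __post_init__(self):
        if self.only_save_model is None:
            # LoRA keeps the full trainer state (adapter checkpoints are small);
            # full fine-tuning defaults to model-only checkpoints, since the old
            # code already noted that saving the whole model takes a long time.
            self.only_save_model = self.sft_type == 'full'
        if self.save_steps is None:
            self.save_steps = self.eval_steps


print(Args(sft_type='lora').only_save_model)  # False
print(Args(sft_type='full').only_save_model)  # True; save_steps == eval_steps either way
```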

examples/pytorch/llm/src/utils/dataset.py

Lines changed: 8 additions & 7 deletions
@@ -137,13 +137,14 @@ def _process_mutimodal_dataset(dataset: HfDataset, prompt: str, image_key: str,
                                response_key: str) -> HfDataset:
     dataset._info.features._column_requires_decoding['image'] = False
     query_format = f'<img>{{image_path}}</img>{prompt}'
-    query = [
-        query_format.format(image_path=d[image_key]['path']) for d in dataset
-    ]
-    dataset = HfDataset.from_dict({
-        'query': query,
-        'response': dataset[response_key]
-    })
+    query = []
+    response = []
+    for d in tqdm(dataset):
+        query.append(query_format.format(image_path=d[image_key]['path']))
+        if '&&' in d[response_key]:
+            d[response_key] = d[response_key].split('&&')[0]
+        response.append(d[response_key])
+    dataset = HfDataset.from_dict({'query': query, 'response': response})
     return dataset
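The rewritten loop also normalizes responses: when a multimodal sample packs several alternative answers separated by '&&', only the first one is kept. A tiny illustration with made-up data:

```python
# Hypothetical response value; real ones come from the dataset's response column.
response = 'a cat sitting on a sofa&&a kitten resting&&tabby cat indoors'
if '&&' in response:
    response = response.split('&&')[0]
print(response)  # -> 'a cat sitting on a sofa'
```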
examples/pytorch/llm/src/utils/model.py

Lines changed: 2 additions & 2 deletions
@@ -222,14 +222,14 @@ class LoRATM(NamedTuple):
     },
     'chatglm2-6b': {
         'model_id': 'ZhipuAI/chatglm2-6b',
-        'revision': 'v1.0.8',
+        'revision': 'v1.0.9',
         'get_function': get_model_tokenizer_chatglm2,
         'template': 'chatglm2',
         'lora_TM': LoRATM.chatglm2,
     },
     'chatglm2-6b-32k': {
         'model_id': 'ZhipuAI/chatglm2-6b-32k',
-        'revision': 'v1.0.0',
+        'revision': 'v1.0.1',
         'template': 'chatglm2',
         'lora_TM': LoRATM.chatglm2,
     },

swift/trainers/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -4,7 +4,6 @@
                                     HPSearchBackend, HubStrategy,
                                     IntervalStrategy, SchedulerType,
                                     ShardedDDPOption)
-from transformers.training_args import TrainingArguments
-from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
 
+from .arguments import Seq2SeqTrainingArguments, TrainingArguments
 from .trainers import Seq2SeqTrainer, Trainer

swift/trainers/arguments.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from dataclasses import dataclass
+
+from transformers.training_args import TrainingArguments as HfTrainingArguments
+from transformers.training_args_seq2seq import \
+    Seq2SeqTrainingArguments as HfSeq2SeqTrainingArguments
+
+
+@dataclass
+class SwiftArgumentsMixin:
+    # ckpt only save model
+    only_save_model: bool = False
+
+
+@dataclass
+class TrainingArguments(SwiftArgumentsMixin, HfTrainingArguments):
+    pass
+
+
+@dataclass
+class Seq2SeqTrainingArguments(SwiftArgumentsMixin,
+                               HfSeq2SeqTrainingArguments):
+    pass
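With the mixin in front of the Hugging Face classes, the swift argument types stay drop-in replacements that simply grow one extra field. A short usage sketch, assuming the swift package from this commit is importable:

```python
from swift.trainers import Seq2SeqTrainingArguments

# All standard transformers fields are still accepted; only_save_model is the
# extra field contributed by SwiftArgumentsMixin (default: False).
args = Seq2SeqTrainingArguments(output_dir='output', only_save_model=True)
print(args.only_save_model)  # True
```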

swift/trainers/mixin.py

Lines changed: 52 additions & 1 deletion
@@ -1,11 +1,12 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-
+# Part of the implementation is borrowed from huggingface/transformers.
 import os
 import shutil
 from types import MethodType
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import json
+import numpy as np
 import safetensors
 import torch
 from datasets import Dataset as HfDataset
@@ -15,6 +16,7 @@
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from transformers.data.data_collator import DataCollator
 from transformers.modeling_utils import unwrap_model
+from transformers.trainer import PREFIX_CHECKPOINT_DIR, TRAINER_STATE_NAME
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalPrediction, HubStrategy
 from transformers.training_args import TrainingArguments
@@ -278,3 +280,52 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
         torch.save(self.args, os.path.join(output_dir, 'training_args.bin'))
+
+    def _save_checkpoint(self, model, trial, metrics=None):
+        only_save_model = getattr(self.args, 'only_save_model', False)
+        if only_save_model:
+            return self._only_save_model(model, trial, metrics)
+        else:
+            return super()._save_checkpoint(model, trial, metrics)
+
+    def _only_save_model(self, model, trial, metrics=None):
+        # Save model checkpoint
+        checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}'
+
+        if self.hp_search_backend is None and trial is None:
+            self.store_flos()
+
+        run_dir = self._get_output_dir(trial=trial)
+        output_dir = os.path.join(run_dir, checkpoint_folder)
+        self.save_model(output_dir, _internal_call=True)
+        if self.is_deepspeed_enabled:
+            # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
+            # config `stage3_gather_16bit_weights_on_model_save` is True
+            self.model_wrapped.save_checkpoint(output_dir)
+
+        # Determine the new best metric / best model checkpoint
+        if metrics is not None and self.args.metric_for_best_model is not None:
+            metric_to_check = self.args.metric_for_best_model
+            if not metric_to_check.startswith('eval_'):
+                metric_to_check = f'eval_{metric_to_check}'
+            metric_value = metrics[metric_to_check]
+
+            operator = np.greater if self.args.greater_is_better else np.less
+            if (self.state.best_metric is None
+                    or self.state.best_model_checkpoint is None
+                    or operator(metric_value, self.state.best_metric)):
+                self.state.best_metric = metric_value
+                self.state.best_model_checkpoint = output_dir
+
+        # Save the Trainer state
+        if self.args.should_save:
+            self.state.save_to_json(
+                os.path.join(output_dir, TRAINER_STATE_NAME))
+
+        # push to hub
+        if self.args.push_to_hub:
+            self._push_from_checkpoint(output_dir)
+
+        # Maybe delete some older checkpoints.
+        if self.args.should_save:
+            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
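The net effect of _only_save_model is that a checkpoint directory contains the model weights, tokenizer files and trainer state, but not the optimizer, scheduler or RNG state that the stock _save_checkpoint also writes, so checkpoints are much smaller at the cost of not being resumable with optimizer state. A rough way to tell the two layouts apart on disk (filenames follow the usual transformers conventions; the example path is hypothetical):

```python
import os


def describe_checkpoint(ckpt_dir: str) -> None:
    """Report whether a transformers checkpoint keeps optimizer/scheduler state."""
    files = set(os.listdir(ckpt_dir))
    if {'optimizer.pt', 'scheduler.pt'} & files:
        print('full checkpoint: resumable with optimizer and scheduler state')
    else:
        print('model-only checkpoint: weights, tokenizer and trainer state only')


# describe_checkpoint('output/qwen-7b-chat/checkpoint-100')  # hypothetical path
```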
