
Commit 9240897

fix not saving the last ckpt bug (#29)
1 parent 55eec46 commit 9240897

File tree

11 files changed (+132, −35 lines)


.gitignore

Lines changed: 2 additions & 0 deletions
@@ -131,3 +131,5 @@ result.mp4
 
 # ast template
 ast_index_file.py
+
+runs/

examples/pytorch/llm/README.md

Lines changed: 4 additions & 4 deletions
@@ -4,7 +4,7 @@
 <p align="center">
 <img src="https://img.shields.io/badge/python-%E2%89%A53.8-5be.svg">
 <img src="https://img.shields.io/badge/pytorch-%E2%89%A51.12%20%7C%20%E2%89%A52.0-orange.svg">
-<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.8.1-5D91D4.svg"></a>
+<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.8.4-5D91D4.svg"></a>
 <a href="https://github.com/modelscope/swift/"><img src="https://img.shields.io/badge/ms--swift-%E2%89%A51.0.0-6FEBB9.svg"></a>
 </p>
 
@@ -16,10 +16,10 @@
 
 ## Features
 1. supported sft method: [lora](https://arxiv.org/abs/2106.09685), [qlora](https://arxiv.org/abs/2305.14314), full(full parameter fine tuning), ...
-2. supported models: [**qwen-7b**](https://github.com/QwenLM/Qwen-7B), qwen-7b-chat, baichuan-7b, baichuan-13b, baichuan-13b-chat, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-7b-chat, llama2-13b, llama2-13b-chat, llama2-70b, llama2-70b-chat, openbuddy-llama2-13b, openbuddy-llama-65b, polylm-13b, ...
+2. supported models: [**qwen-7b**](https://github.com/QwenLM/Qwen-7B), qwen-7b-chat, baichuan-7b, baichuan-13b, baichuan-13b-chat, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-7b-chat, llama2-13b, llama2-13b-chat, llama2-70b, llama2-70b-chat, openbuddy-llama2-13b, openbuddy-llama-65b, polylm-13b
 3. supported feature: quantization, ddp, model parallelism(device map), gradient checkpoint, gradient accumulation steps, push to modelscope hub, custom datasets, ...
-4. supported datasets: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, ...
-5. supported templates: chatml(qwen), baichuan, chatglm2, llama, openbuddy_llama, default, ...
+4. supported datasets: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, cot-en, cot-zh
+5. supported templates: chatml(qwen), baichuan, chatglm2, llama, openbuddy_llama, default
 
 ## Prepare the Environment
 Experimental environment: A10, 3090, A100, ... (V100 does not support bf16, quantization)

examples/pytorch/llm/README_CN.md

Lines changed: 4 additions & 4 deletions
@@ -4,7 +4,7 @@
 <p align="center">
 <img src="https://img.shields.io/badge/python-%E2%89%A53.8-5be.svg">
 <img src="https://img.shields.io/badge/pytorch-%E2%89%A51.12%20%7C%20%E2%89%A52.0-orange.svg">
-<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.8.1-5D91D4.svg"></a>
+<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.8.4-5D91D4.svg"></a>
 <a href="https://github.com/modelscope/swift/"><img src="https://img.shields.io/badge/ms--swift-%E2%89%A51.0.0-6FEBB9.svg"></a>
 </p>
 
@@ -17,10 +17,10 @@
 
 ## 特性
 1. [lora](https://arxiv.org/abs/2106.09685), [qlora](https://arxiv.org/abs/2305.14314), 全参数微调, ...
-2. 支持的模型: [**qwen-7b**](https://github.com/QwenLM/Qwen-7B), qwen-7b-chat, baichuan-7b, baichuan-13b, baichuan-13b-chat, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-7b-chat, llama2-13b, llama2-13b-chat, llama2-70b, llama2-70b-chat, openbuddy-llama2-13b, openbuddy-llama-65b, polylm-13b, ...
+2. 支持的模型: [**qwen-7b**](https://github.com/QwenLM/Qwen-7B), qwen-7b-chat, baichuan-7b, baichuan-13b, baichuan-13b-chat, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-7b-chat, llama2-13b, llama2-13b-chat, llama2-70b, llama2-70b-chat, openbuddy-llama2-13b, openbuddy-llama-65b, polylm-13b
 3. 支持的特性: 模型量化, DDP, 模型并行(device_map), gradient checkpoint, 梯度累加, 支持推送modelscope hub, 支持自定义数据集, ...
-4. 支持的数据集: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, ...
-5. 支持的template: chatml(qwen), baichuan, chatglm2, llama, openbuddy_llama, default, ...
+4. 支持的数据集: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, cot-en, cot-zh
+5. 支持的template: chatml(qwen), baichuan, chatglm2, llama, openbuddy_llama, default
 
 ## 准备实验环境
 实验环境: A10, 3090, A100均可. (V100不支持bf16, 量化)

examples/pytorch/llm/src/llm_infer.py

Lines changed: 2 additions & 1 deletion
@@ -18,7 +18,8 @@
 @dataclass
 class InferArguments:
     model_type: str = field(
-        default='qwen-7b-chat', metadata={'choices': list(MODEL_MAPPING.keys())})
+        default='qwen-7b-chat',
+        metadata={'choices': list(MODEL_MAPPING.keys())})
     sft_type: str = field(
         default='lora', metadata={'choices': ['lora', 'full']})
     template_type: str = field(

examples/pytorch/llm/src/llm_sft.py

Lines changed: 6 additions & 4 deletions
@@ -69,7 +69,7 @@ class SftArguments:
     lora_alpha: int = 32
     lora_dropout_p: float = 0.1
 
-    gradient_checkpoint: bool = True
+    gradient_checkpointing: bool = True
     batch_size: int = 1
     num_train_epochs: int = 1
     optim: str = 'adamw_torch'
@@ -84,6 +84,7 @@ class SftArguments:
     save_steps: Optional[int] = None
     save_total_limit: int = 2
     logging_steps: int = 5
+    dataloader_num_workers: int = 1
 
     push_to_hub: bool = False
     # 'user_name/repo_name' or 'repo_name'
@@ -263,7 +264,7 @@ def llm_sft(args: SftArguments) -> None:
         bf16=args.bf16,
         fp16=args.fp16,
         eval_steps=args.eval_steps,
-        dataloader_num_workers=1,
+        dataloader_num_workers=args.dataloader_num_workers,
         load_best_model_at_end=True,
         metric_for_best_model='loss',
         greater_is_better=False,
@@ -276,11 +277,12 @@ def llm_sft(args: SftArguments) -> None:
         push_to_hub=args.push_to_hub,
         resume_from_checkpoint=args.resume_from_ckpt,
         ddp_backend=args.ddp_backend,
-        gradient_checkpointing=args.gradient_checkpoint,
+        gradient_checkpointing=args.gradient_checkpointing,
         local_rank=local_rank)
 
-    if args.gradient_checkpoint:
+    if args.gradient_checkpointing:
         # fix: gradients will be None
+        model.config.use_cache = False
         model.enable_input_require_grads()
     if is_dist():
         trainer_args.ddp_find_unused_parameters = False
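Note: the `gradient_checkpoint` → `gradient_checkpointing` rename aligns the flag with the transformers argument name, and the added `model.config.use_cache = False` follows the usual recipe for training causal LMs with gradient checkpointing: KV caching is an inference-time feature that conflicts with checkpointed re-computation, while `enable_input_require_grads()` keeps checkpointed blocks from producing `None` gradients. A minimal sketch of that recipe (illustrative, not part of this commit; `model` stands for any Hugging Face-style causal LM):

def prepare_for_gradient_checkpointing(model):
    # KV caching only helps at inference time and conflicts with checkpointed
    # re-computation, so turn it off for training.
    model.config.use_cache = False
    # Mark input embeddings as requiring grad; without this the first
    # checkpointed block sees requires_grad=False inputs and gradients come
    # back as None (the "fix: gradients will be None" comment in the diff).
    model.enable_input_require_grads()
    # The checkpointing itself is switched on via the trainer arguments,
    # e.g. gradient_checkpointing=True, as in the hunk above.
    return model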

examples/pytorch/llm/src/utils/dataset.py

Lines changed: 15 additions & 1 deletion
@@ -110,6 +110,18 @@ def get_instinwild_en_dataset():
     return _processing_alpaca(dataset)
 
 
+def get_cot_en_dataset() -> HfDataset:
+    dataset: HfDataset = MsDataset.load(
+        'YorickHe/CoT', split='train').to_hf_dataset()
+    return _processing_alpaca(dataset)
+
+
+def get_cot_zh_dataset() -> HfDataset:
+    dataset: HfDataset = MsDataset.load(
+        'YorickHe/CoT_zh', split='train').to_hf_dataset()
+    return _processing_alpaca(dataset)
+
+
 DATASET_MAPPING = {
     'alpaca-en': get_alpaca_gpt4_en_dataset,
     'alpaca-zh': get_alpaca_gpt4_zh_dataset,
@@ -120,8 +132,10 @@ def get_instinwild_en_dataset():
         for k in _multi_alpaca_language_list
     },
     'code-en': get_code_alpaca_en_dataset,
-    'instinwild-zh': get_instinwild_zh_dataset,
     'instinwild-en': get_instinwild_en_dataset,
+    'instinwild-zh': get_instinwild_zh_dataset,
+    'cot-en': get_cot_en_dataset,
+    'cot-zh': get_cot_zh_dataset,
 }
 
 
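Note: both new loaders follow the file's existing pattern: pull the dataset from the ModelScope hub, convert it to a Hugging Face `Dataset`, and normalize it with the shared alpaca preprocessor; registering the function under a key in `DATASET_MAPPING` is what makes it selectable by name (hence the new `cot-en`/`cot-zh` entries in the READMEs). A hedged sketch of adding one more dataset the same way — the dataset id and mapping key below are placeholders, not real entries:

from datasets import Dataset as HfDataset
from modelscope import MsDataset


def get_my_cot_dataset() -> HfDataset:
    # Placeholder id; any hub dataset with alpaca-style fields would work here.
    dataset: HfDataset = MsDataset.load(
        'my-namespace/MyCoTDataset', split='train').to_hf_dataset()
    return _processing_alpaca(dataset)  # shared helper defined in dataset.py


DATASET_MAPPING['my-cot'] = get_my_cot_dataset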

examples/pytorch/llm/src/utils/model.py

Lines changed: 10 additions & 5 deletions
@@ -4,8 +4,8 @@
 from typing import NamedTuple, Optional
 
 import torch
-from modelscope import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, Model,
-                        read_config, snapshot_download)
+from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM,
+                        AutoTokenizer, Model, read_config, snapshot_download)
 from torch import dtype as Dtype
 
 from swift import get_logger
@@ -18,6 +18,7 @@ def get_model_tokenizer_from_repo(model_dir: str,
                                   load_model: bool = True,
                                   model_config=None,
                                   tokenizer=None,
+                                  automodel_class=AutoModelForCausalLM,
                                   **model_kwargs):
     """load from an independent repository"""
     if model_config is None:
@@ -30,7 +31,7 @@ def get_model_tokenizer_from_repo(model_dir: str,
             model_dir, trust_remote_code=True)
     model = None
     if load_model:
-        model = AutoModelForCausalLM.from_pretrained(
+        model = automodel_class.from_pretrained(
             model_dir,
             config=model_config,
             torch_dtype=torch_dtype,
@@ -88,8 +89,12 @@ def get_model_tokenizer_chatglm2(model_dir: str,
         model_kwargs['quantization_config'].llm_int8_skip_modules = [
             'output_layer'
         ]
-    return get_model_tokenizer_from_repo(model_dir, torch_dtype, load_model,
-                                         **model_kwargs)
+    return get_model_tokenizer_from_repo(
+        model_dir,
+        torch_dtype,
+        load_model,
+        automodel_class=AutoModel,
+        **model_kwargs)
 
 
 def get_model_tokenizer_llama2(model_dir: str,
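Note: the new `automodel_class` parameter exists because the chatglm2 repositories expose their LM through `AutoModel` rather than `AutoModelForCausalLM`, so the generic loader now lets each model-specific wrapper choose the auto class. A simplified sketch of the dispatch idea (signature trimmed; the real loader also handles `model_config`, `tokenizer` and quantization kwargs):

import torch
from modelscope import AutoModel, AutoModelForCausalLM, AutoTokenizer


def load_model_tokenizer(model_dir: str, automodel_class=AutoModelForCausalLM):
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    model = automodel_class.from_pretrained(
        model_dir, torch_dtype=torch.float16, trust_remote_code=True)
    return model, tokenizer


# chatglm2 checkpoints register the model under AutoModel, so its loader calls:
# load_model_tokenizer('/path/to/chatglm2-6b', automodel_class=AutoModel)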

examples/pytorch/llm/src/utils/preprocess.py

Lines changed: 5 additions & 5 deletions
@@ -51,8 +51,8 @@
 
 
 def simplify_context_list(context_list: List[Context]) -> List[Context]:
-    res = []
-    temp = []
+    res: List[Context] = []
+    temp: List[str] = []
     for c in context_list:
         if isinstance(c, str):
             temp.append(c)
@@ -89,7 +89,7 @@ def concat_context_list(
 
 def _encode(tokenizer: PreTrainedTokenizer, context_list: List[Context],
             placeholder_list: List[str]) -> List[int]:
-    input_ids = []
+    input_ids: List[int] = []
     placeholder_it = iter(placeholder_list)
     for context in context_list:
         if isinstance(context, list):
@@ -126,8 +126,8 @@ def _preprocess(
     template_config = TEMPLATE_MAPPING[template_type]
     if system is None:
         system = DEFAULT_SYSTEM
-    total_context_list = []
-    placeholder_list = []
+    total_context_list: List[Context] = []
+    placeholder_list: List[str] = []
     concat_context_list(
         template_config['prefix'],
         total_context_list,
examples/pytorch/llm/src/utils/trainer_patch.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+import os
+
+import json
+from tqdm import tqdm
+from transformers.trainer_callback import (DefaultFlowCallback,
+                                           ProgressCallback, TrainerControl,
+                                           TrainerState)
+from transformers.trainer_utils import has_length
+
+from swift.trainers import TrainingArguments
+
+
+class ProgressCallbackNew(ProgressCallback):
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        if state.is_local_process_zero:
+            self.training_bar = tqdm(total=state.max_steps, dynamic_ncols=True)
+        self.current_step = 0
+
+    def on_prediction_step(self,
+                           args,
+                           state: TrainerState,
+                           control,
+                           eval_dataloader=None,
+                           **kwargs):
+        if state.is_local_process_zero and has_length(eval_dataloader):
+            if self.prediction_bar is None:
+                self.training_bar.refresh()
+                self.training_bar.fp.write('\n')
+                self.prediction_bar = tqdm(
+                    total=len(eval_dataloader),
+                    leave=True,
+                    dynamic_ncols=True,
+                    position=0)
+            self.prediction_bar.update()
+
+    def on_log(self,
+               args: TrainingArguments,
+               state: TrainerState,
+               control,
+               logs=None,
+               **kwargs):
+        logs['global_step'] = state.global_step
+        if 'learning_rate' in logs:
+            logs['learning_rate'] = round(logs['learning_rate'], 8)
+        if state.is_local_process_zero and self.training_bar is not None:
+            jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
+            with open(jsonl_path, 'a') as f:
+                f.write(json.dumps(logs) + '\n')
+        super().on_log(args, state, control, logs, **kwargs)
+
+
+class DefaultFlowCallbackNew(DefaultFlowCallback):
+
+    def on_step_end(self, args: TrainingArguments, state: TrainerState,
+                    control: TrainerControl, **kwargs):
+        control = super().on_step_end(args, state, control, **kwargs)
+        # save the last ckpt
+        if state.global_step == state.max_steps:
+            control.should_evaluate = True
+            control.should_save = True
+        return control
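Note: `DefaultFlowCallbackNew` is the actual fix behind the commit title: the stock `DefaultFlowCallback` only sets `should_save` every `save_steps`, so a run whose `max_steps` is not a multiple of `save_steps` ends without a final checkpoint. Forcing `should_save` (and `should_evaluate`) when `global_step` reaches `max_steps` closes that gap. A hedged usage sketch of the same idea attached to a plain transformers Trainer directly, instead of through the monkey patch in utils.py below (model/dataset wiring omitted):

from transformers import TrainerCallback


class SaveLastStepCallback(TrainerCallback):
    """Same idea as DefaultFlowCallbackNew: force a save at the final step."""

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step == state.max_steps:
            control.should_evaluate = True
            control.should_save = True
        return control


# trainer = Trainer(model=model, args=training_args,
#                   train_dataset=train_dataset,
#                   callbacks=[SaveLastStepCallback()])
# trainer.train()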

examples/pytorch/llm/src/utils/utils.py

Lines changed: 22 additions & 2 deletions
@@ -1,19 +1,39 @@
+import logging
 import os
 from typing import List, Optional, Tuple
 
 import matplotlib.pyplot as plt
 import torch
 import torch.distributed as dist
+from modelscope.utils.logger import get_logger as get_ms_logger
 from torch import dtype as Dtype
 from torch.nn import Linear, Module
-from transformers import GenerationConfig, TextStreamer
+from transformers import GenerationConfig, TextStreamer, trainer
 
 from swift import get_logger
+from swift.utils import is_master
 from swift.utils.tb_utils import (TB_COLOR, TB_COLOR_SMOOTH,
                                   read_tensorboard_file, tensorboard_smoothing)
+from .trainer_patch import DefaultFlowCallbackNew, ProgressCallbackNew
+
+# monkey patch
+trainer.DEFAULT_PROGRESS_CALLBACK = ProgressCallbackNew
+trainer.DEFAULT_CALLBACKS = [DefaultFlowCallbackNew]
 
-os.environ['TOKENIZERS_PARALLELISM'] = 'true'
 logger = get_logger()
+ms_logger = get_ms_logger()
+
+os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+logger_format = logging.Formatter('[%(levelname)s:%(name)s] %(message)s')
+
+logger.handlers[0].setFormatter(logger_format)
+ms_logger.handlers[0].setFormatter(logger_format)
+if is_master():
+    logger.setLevel(logging.INFO)
+    ms_logger.setLevel(logging.INFO)
+else:
+    logger.setLevel(logging.ERROR)
+    ms_logger.setLevel(logging.ERROR)
 
 DTYPE_MAPPING = {
     'fp16': torch.float16,
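Note: the monkey patch works because `transformers.Trainer` reads the module-level `DEFAULT_CALLBACKS` and `DEFAULT_PROGRESS_CALLBACK` names in `transformers.trainer` when it is constructed, so rebinding them before any Trainer is created swaps in the patched classes for every trainer in the process without touching call sites; the log-level block then keeps non-master ranks quiet under DDP. A minimal sketch of the patching mechanism (illustrative; the placeholder subclasses stand in for `ProgressCallbackNew` and `DefaultFlowCallbackNew`):

from transformers import trainer
from transformers.trainer_callback import DefaultFlowCallback, ProgressCallback


class MyProgressCallback(ProgressCallback):
    """Placeholder standing in for ProgressCallbackNew."""


class MyFlowCallback(DefaultFlowCallback):
    """Placeholder standing in for DefaultFlowCallbackNew."""


# Must run before any Trainer is instantiated; afterwards every Trainer in this
# process builds its default callback list from the patched classes.
trainer.DEFAULT_PROGRESS_CALLBACK = MyProgressCallback
trainer.DEFAULT_CALLBACKS = [MyFlowCallback]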
