
Commit e20a65a

fix ppo zero3 (#4263)
1 parent 88094e0

6 files changed, +89 -4 lines

examples/train/multimodal/rlhf/dpo/full.sh

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+# 4 * 50GiB
+nproc_per_node=4
+
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+NPROC_PER_NODE=$nproc_per_node \
+MAX_PIXELS=1003520 \
+swift rlhf \
+    --rlhf_type dpo \
+    --model Qwen/Qwen2.5-VL-7B-Instruct \
+    --dataset 'swift/RLAIF-V-Dataset#20000' \
+    --train_type full \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --learning_rate 1e-5 \
+    --freeze_vit true \
+    --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
+    --eval_steps 100 \
+    --save_steps 100 \
+    --save_total_limit 2 \
+    --deepspeed zero3 \
+    --logging_steps 5 \
+    --max_length 4096 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 4 \
+    --dataset_num_proc 4 \
+    --save_only_model true
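
Note on the accumulation arithmetic: `--gradient_accumulation_steps $(expr 16 / $nproc_per_node)` keeps the effective global batch size pinned at 16 regardless of the GPU count (the PPO script further down does the same with 8 processes, yielding 2 accumulation steps). A quick sanity check of that arithmetic, assuming the usual product of data-parallel processes, per-device batch size, and accumulation steps:

    # Effective global batch = data-parallel processes * per-device batch * accumulation steps.
    nproc_per_node = 4                                   # GPUs used by this script
    per_device_train_batch_size = 1
    gradient_accumulation_steps = 16 // nproc_per_node   # expr 16 / $nproc_per_node -> 4
    global_batch = nproc_per_node * per_device_train_batch_size * gradient_accumulation_steps
    print(global_batch)  # 16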

examples/train/multimodal/rlhf/dpo.sh renamed to examples/train/multimodal/rlhf/dpo/lora.sh

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-# 4*50GiB
+# 4 * 50GiB
 # You can refer to `https://github.com/QwenLM/Qwen2.5-VL` for the meaning of the `MAX_PIXELS` parameter.
 # --rlhf_type cpo/orpo/simpo/rm are also supported
 nproc_per_node=2

examples/train/rlhf/ppo/full.sh

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# 8 * 65 GiB
+# Currently, it only supports the case where the model and reward_model use the same template/tokenizer.
+# Currently, multimodal model PPO is not supported.
+nproc_per_node=8
+
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+NPROC_PER_NODE=$nproc_per_node \
+swift rlhf \
+    --rlhf_type ppo \
+    --model LLM-Research/Meta-Llama-3.1-8B-Instruct \
+    --reward_model 'AI-ModelScope/Skywork-Reward-Llama-3.1-8B-v0.2' \
+    --train_type full \
+    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#20000' 'AI-ModelScope/alpaca-gpt4-data-en#20000' \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --learning_rate 1e-6 \
+    --gradient_accumulation_steps $(expr 16 / $nproc_per_node) \
+    --eval_steps 100 \
+    --save_steps 100 \
+    --save_total_limit 2 \
+    --logging_steps 5 \
+    --max_length 2048 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --dataloader_num_workers 4 \
+    --deepspeed zero3 \
+    --response_length 512 \
+    --temperature 0.7 \
+    --dataset_num_proc 4 \
+    --save_only_model true

examples/train/rlhf/ppo.sh renamed to examples/train/rlhf/ppo/lora.sh

Lines changed: 4 additions & 1 deletion
@@ -1,7 +1,9 @@
+# 4 * 50GiB
 # Currently, it only supports the case where the model and reward_model use the same template/tokenizer.
 # Currently, multimodal model PPO is not supported.
 nproc_per_node=4

+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 NPROC_PER_NODE=$nproc_per_node \
 swift rlhf \
@@ -30,4 +32,5 @@ swift rlhf \
     --deepspeed zero2 \
     --response_length 512 \
     --temperature 0.7 \
-    --dataset_num_proc 4
+    --dataset_num_proc 4 \
+    --save_only_model true

swift/trainers/mixin.py

Lines changed: 1 addition & 1 deletion
@@ -170,7 +170,7 @@ def _load_optimizer_and_scheduler(self, *args, **kwargs):
     def _save_model(self, output_dir: Optional[str] = None, state_dict=None):
         # model
         supported_classes = (SwiftModel, PreTrainedModel, PeftModel)
-        supported_names = ('SentenceTransformer')
+        supported_names = ('SentenceTransformer', )
         if AutoModelForCausalLMWithValueHead is not None:
             supported_classes = supported_classes + (AutoModelForCausalLMWithValueHead, )
         save_safetensors = self.args.save_safetensors
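
The one-character fix above matters because `('SentenceTransformer')` is only a parenthesized string; the trailing comma is what makes it a one-element tuple. A minimal standalone sketch of the difference (it does not reproduce the exact check `supported_names` feeds into in mixin.py):

    # Parentheses alone do not create a tuple; the trailing comma does.
    as_str = ('SentenceTransformer')
    as_tuple = ('SentenceTransformer', )

    print(type(as_str).__name__)    # str
    print(type(as_tuple).__name__)  # tuple

    # Membership tests then behave differently: `in` on a string is a
    # substring match, while `in` on a tuple compares whole elements.
    print('Transformer' in as_str)    # True  (substring)
    print('Transformer' in as_tuple)  # False (not an element)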

swift/trainers/rlhf_trainer/ppo_trainer.py

Lines changed: 20 additions & 1 deletion
@@ -1,11 +1,12 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import inspect
 from contextlib import contextmanager
+from typing import Optional

 import transformers
 from packaging import version
 from torch.utils.data import DataLoader
-from transformers import PreTrainedModel
+from transformers import PreTrainedModel, Trainer
 from trl import PPOTrainer as HFPPOTrainer

 from swift.utils import patch_getattr
@@ -63,3 +64,21 @@ def _save_checkpoint(self, *args, **kwargs):
         trial = kwargs.get('trial')
         self._determine_best_metric(metrics=metrics, trial=trial)
         return super()._save_checkpoint(*args, **kwargs)
+
+    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
+        # https://github.com/huggingface/trl/issues/2122
+        backup_model = self.model
+        self.model = self.model.policy  # save only the policy
+
+        Trainer.save_model(self, output_dir, _internal_call)
+
+        self.model = backup_model
+
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        if self.is_deepspeed_enabled:
+            state_dict = {
+                name.removeprefix('policy.'): param
+                for name, param in state_dict.items() if name.startswith('policy.')
+            }
+
+        super()._save(output_dir, state_dict)
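
These two overrides are the heart of the ZeRO-3 fix. TRL's `PPOTrainer` trains a wrapper around the policy (hence `self.model.policy`), so `save_model` temporarily swaps the policy in before delegating to the stock `Trainer.save_model` (see the linked trl issue), while under DeepSpeed the gathered `state_dict` still carries the wrapper's `policy.` key prefix alongside non-policy weights, which `_save` filters out and strips. A toy run of that same comprehension, with made-up parameter names and shapes:

    import torch

    # Pretend state_dict gathered from the wrapped PPO module under ZeRO-3.
    wrapped_state_dict = {
        'policy.model.embed_tokens.weight': torch.zeros(4, 2),
        'policy.lm_head.weight': torch.zeros(4, 2),
        'value_model.score.weight': torch.zeros(1, 2),  # not part of the policy -> dropped
    }

    # Same filtering/renaming as the `_save` override: keep policy weights and
    # drop the 'policy.' namespace so keys match the standalone model.
    policy_state_dict = {
        name.removeprefix('policy.'): param
        for name, param in wrapped_state_dict.items() if name.startswith('policy.')
    }
    print(sorted(policy_state_dict))  # ['lm_head.weight', 'model.embed_tokens.weight']

Note that `str.removeprefix` requires Python 3.9 or newer.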
