
Commit 3e9d351

dddd-d and DesmonDay authored
Update ppo (#10912)
* add ppo
* Update run_rl.py
* Update ppo_trainer.py
* Update config_utils.py
* Update run_rl.py
* Update gsm8k_processor.py
* Update run_rl.py
* Update score_model_utils.py
* Update score_model_utils.py
* Update run_rl.py
* Update ppo_trainer.py
* Update ppo_trainer.py
* Update ppo_trainer.py
* Update score_model_utils.py
* Update gsm8k_processor.py
* pre_commit
* Update advantage.py
* Update advantage.py
* pre-commit check

---------

Co-authored-by: DesmonDay <[email protected]>
1 parent 2ae914e commit 3e9d351

File tree: 15 files changed (+634, −303 lines)


llm/alignment/rl/gsm8k_processor.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the GSM8k dataset into JSONL prompt/response files.
"""

import argparse
import os
import re

import datasets


def extract_solution(solution_str):
    solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
    assert solution is not None
    final_solution = solution.group(0)
    final_solution = final_solution.split("#### ")[1].replace(",", "")
    return final_solution


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", default="./gsm8k")

    args = parser.parse_args()

    data_source = "openai/gsm8k"

    dataset = datasets.load_dataset(data_source, "main")

    train_dataset = dataset["train"]
    test_dataset = dataset["test"]

    instruction_following = 'Let\'s think step by step and output the final answer after "####".'

    # build the chat-formatted prompt and extract the final numeric answer for each item
    def make_map_fn(split):
        def process_fn(example, idx):
            question_raw = "<|im_start|>user\n" + example.pop("question")

            system_raw = (
                "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
            )
            question = system_raw + question_raw + " " + instruction_following + "<|im_end|>\n<|im_start|>assistant\n"

            answer_raw = example.pop("answer")
            solution = extract_solution(answer_raw)
            data = {
                "src": question,
                "tgt": solution,
            }
            return data

        return process_fn

    train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
    test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)

    local_dir = args.local_dir

    train_dataset.to_json(os.path.join(local_dir, "train.jsonl"), orient="records", lines=True)
    test_dataset.to_json(os.path.join(local_dir, "test.jsonl"), orient="records", lines=True)
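
Each processed record thus carries the Qwen chat-formatted prompt under "src" and the bare numeric answer under "tgt". A minimal, self-contained sketch of what extract_solution produces; the answer strings below are illustrative, not taken from the dataset:

import re


def extract_solution(solution_str):
    # Same regex as gsm8k_processor.py above: grab the number after "#### ".
    solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
    assert solution is not None
    return solution.group(0).split("#### ")[1].replace(",", "")


# Illustrative GSM8K-style reference answers (made up for this example).
assert extract_solution("She sold 48 + 24 = 72 clips.\n#### 72") == "72"
assert extract_solution("Total cost is $1,200.\n#### 1,200") == "1200"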

llm/alignment/rl/run_rl.py

Lines changed: 47 additions & 41 deletions
@@ -42,6 +42,7 @@
 from paddlenlp.transformers import (
     AutoConfig,
     AutoModelForCausalLM,
+    AutoModelForTokenClassification,
     AutoTokenizer,
     PretrainedConfig,
 )
@@ -134,7 +135,6 @@ def create_actor_models(
     )
     if not training_args.autotuner_benchmark:
         reference_model.set_state_dict(actor_model.state_dict())
-
     actor_tokenizer = AutoTokenizer.from_pretrained(
         model_args.actor_model_name_or_path,
         model_max_length=data_args.max_length,
@@ -210,46 +210,43 @@ def create_critic_models(
     data_args: DataArgument,
     training_args: TrainingArguments,
     common_config: Dict,
-    reward_model,
 ):
     with timers_scope_runtimer("Critic model loading time"):
-        reward_model_config = reward_model.config
-        if model_args.critic_model_name_or_path is None:
-            model_args.critic_model_name_or_path = model_args.reward_model_name_or_path
-            critic_model = AutoModelForScore.from_config(
-                reward_model_config,
-                dtype=training_args.model_dtype,
-                score_type="critic",
-                do_normalize=False,
-                clip_range_value=training_args.clip_range_value,
-                **common_config,
+        critic_model_config = AutoConfig.from_pretrained(
+            model_args.critic_model_name_or_path,
+            tensor_parallel_output=training_args.tensor_parallel_output,
+            tensor_parallel_degree=training_args.tensor_parallel_degree,
+            tensor_parallel_rank=training_args.tensor_parallel_rank,
+            dtype=training_args.model_dtype,
+            recompute=training_args.critic_recompute,
+            recompute_granularity=model_args.critic_recompute_granularity,
+            recompute_use_reentrant=training_args.recompute_use_reentrant,
+            **common_config,
+        )
+        LlmMetaConfig.set_llm_config(critic_model_config, training_args)
+
+        critic_model_config.max_position_embeddings = data_args.max_length
+        critic_model_config.use_sparse_head_and_loss_fn = False
+        critic_model_config.num_labels = 1
+        critic_model_config.classifier_dropout = 0.0
+        critic_model_config.hidden_dropout = 0.0
+        logger.info(f"Loading Critic model with config:\n\t{critic_model_config}\n")
+
+        if not training_args.autotuner_benchmark:
+            critic_model = AutoModelForTokenClassification.from_pretrained(
+                model_args.critic_model_name_or_path,
+                config=critic_model_config,
             )
-            if not training_args.autotuner_benchmark:
-                critic_model.set_state_dict(reward_model.state_dict())
         else:
-            if not training_args.autotuner_benchmark:
-                critic_model = AutoModelForScore.from_pretrained(
-                    model_args.critic_model_name_or_path,
-                    config=reward_model_config,
-                    score_type="critic",
-                    do_normalize=False,
-                    clip_range_value=training_args.clip_range_value,
-                    **common_config,
-                )
-            else:
-                critic_model = AutoModelForScore.from_config(
-                    reward_model_config,
-                    score_type="critic",
-                    do_normalize=False,
-                    clip_range_value=training_args.clip_range_value,
-                    **common_config,
-                )
+            critic_model = AutoModelForTokenClassification.from_config(
+                critic_model_config,
+            )

     critic_tokenizer = AutoTokenizer.from_pretrained(
         model_args.critic_model_name_or_path,
         model_max_length=data_args.max_length,
         padding_side="left",
-        tokenizer_alpha=model_args.reward_critic_tokenizer_alpha,
+        tokenizer_alpha=model_args.critic_tokenizer_alpha,
         use_fast=True,
     )
     if critic_tokenizer.pad_token_id is None:
@@ -261,16 +258,16 @@ def create_critic_models(
     if training_args.eval_mode == "single":
         config.tensor_parallel_degree = -1
         config.tensor_parallel_rank = 0
-        with timers_scope_runtimer("Reward critic eval model loading time"):
-            critic_eval_model = AutoModelForScore.from_config(config)
+        with timers_scope_runtimer("Critic eval model loading time"):
+            critic_eval_model = AutoModelForTokenClassification.from_config(config)
     else:
         critic_eval_model = None

     return critic_model, critic_eval_model, critic_tokenizer


 def create_rl_dataset(data_args, training_args, tokenizer):
-    requires_label = True if training_args.use_rm_server else False
+    requires_label = True if training_args.use_rm_server or training_args.use_rule_reward else False
     train_ds = RLHFDataset(
         dataset_name_or_path=data_args.train_datasets,
         tokenizer=tokenizer,
@@ -333,15 +330,16 @@ def main():
     actor_model, actor_eval_model, reference_model, actor_tokenizer = create_actor_models(
         model_args, data_args, training_args, common_config, reshard_controller
     )
-
-    if not training_args.use_rm_server and model_args.reward_model_name_or_path is not None:
+    if training_args.use_rule_reward:
+        reward_model, reward_tokenizer = None, actor_tokenizer
+    elif not training_args.use_rm_server and model_args.reward_model_name_or_path is not None:
         reward_model, reward_tokenizer = create_reward_models(model_args, data_args, training_args, common_config)
     else:
         reward_model, reward_tokenizer = model_args.reward_server, actor_tokenizer

     if training_args.rl_algorithm == "ppo":
         critic_model, critic_eval_model, critic_tokenizer = create_critic_models(
-            model_args, data_args, training_args, common_config, reward_model
+            model_args, data_args, training_args, common_config
         )
     else:
         critic_model, critic_eval_model, critic_tokenizer = None, None, None
@@ -355,15 +353,23 @@ def main():
         offload_tensor_to_cpu((reference_model, "freeze_model"))

     if training_args.rl_algorithm == "ppo":
-        offload_tensor_to_cpu((reward_model, "freeze_model"))
+        if not training_args.use_rm_server and not training_args.use_rule_reward:
+            offload_tensor_to_cpu((reward_model, "freeze_model"))
         if critic_eval_model is not None:
             offload_tensor_to_cpu((critic_eval_model, "freeze_model"))

     # NOTE(gongenlei): release memory_reserved_size to equal to memory_allocated_size
     paddle.device.cuda.empty_cache()

     def compute_metrics(eval_preds):
-        accuracy = (eval_preds.predictions == 3).astype("float32").mean().item()
+        '''
+        If use_rm_server is True, the score ranges from -3 to 3, with 3 being the only fully correct score (format + result).
+        If the rule-based matching reward is used (use_rule_reward=True, currently only implemented for the gsm8k dataset), the score ranges from 0 to 1.
+        '''
+        if training_args.use_rule_reward:
+            accuracy = (eval_preds.predictions == 1).astype("float32").mean().item()
+        else:
+            accuracy = (eval_preds.predictions == 3).astype("float32").mean().item()
         return {"accuracy": accuracy}

     try:
@@ -389,7 +395,7 @@ def compute_metrics(eval_preds):
         data_collator=partial(
             collate_fn,
             pad_token_id=actor_tokenizer.pad_token_id,
-            requires_label=True if training_args.use_rm_server else False,
+            requires_label=True if training_args.use_rm_server or training_args.use_rule_reward else False,
             max_prompt_len=data_args.max_prompt_len if training_args.balance_batch else None,
         ),  # NOTE: enforce prompt padding to max_prompt_len when using balance_batch
         compute_metrics=compute_metrics,  # TODO: only used for grpo (kk datasets)
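
With this change the critic is no longer derived from the reward model: it is loaded directly from critic_model_name_or_path as a token-classification model whose single label acts as a per-token value head. A minimal standalone sketch of that loading pattern, using only calls visible in the hunks above; the model path and dtype are illustrative assumptions, not the trainer's actual wiring:

from paddlenlp.transformers import AutoConfig, AutoModelForTokenClassification

# Illustrative standalone version of the new critic-loading path (not the trainer code).
critic_path = "Qwen/Qwen2.5-1.5B-Instruct"  # assumed path, as in the example YAML below
critic_config = AutoConfig.from_pretrained(critic_path, dtype="bfloat16")

critic_config.num_labels = 1           # one label per token -> a scalar value head
critic_config.classifier_dropout = 0.0
critic_config.hidden_dropout = 0.0

critic_model = AutoModelForTokenClassification.from_pretrained(critic_path, config=critic_config)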

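For reference, the accuracy in compute_metrics reduces to an exact-match rate over per-sample scores. A toy numpy illustration with made-up score arrays standing in for eval_preds.predictions:

import numpy as np

# Rule-based GSM8K reward (use_rule_reward=True): scores are 0 or 1, so accuracy
# is the fraction of samples scored exactly 1.
rule_scores = np.array([1, 0, 1, 1])  # illustrative values
print((rule_scores == 1).astype("float32").mean().item())  # 0.75

# Reward-model server (use_rm_server=True): scores range from -3 to 3 and only 3
# counts as fully correct (format + result).
rm_scores = np.array([3, 2, -3, 3])  # illustrative values
print((rm_scores == 3).astype("float32").mean().item())  # 0.5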
llm/config/qwen/ppo_argument.yaml

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# RL algorithms
rl_algorithm: "ppo" # The reinforcement learning algorithm used, supported: "ppo", "grpo", "reinforce_plus_plus"

# models
actor_model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct" # The name or path of the actor model
reward_model_name_or_path: "" # The name or path of the reward model
critic_model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct" # The name or path of the critic model
use_rm_server: false # Whether to use the reward model server
reward_server: "http://127.0.0.1:8731" # The address of the reward model server
use_rule_reward: True # Rule-based reward for the gsm8k dataset. If use_rule_reward is true, use_rm_server must be false

# logging
logging_dir: ppo-logs # Directory for logging
logging_steps: 1 # Number of steps between logging
output_dir: "qwen2.5-1.5b-gsm8k-ppo/checkpoints" # Directory for output checkpoints
report_to: "wandb" # Supported reporting options: "all", "wandb", "tensorboard", "visualdl" (default), "none"
wandb_http_proxy: "http://agent.baidu.com:8188" # HTTP proxy for wandb
run_name: "qwen2.5-1.5b-gsm8k-ppo" # Name of the run

# data
train_datasets: "gsm8k/train.jsonl" # Path to the training dataset
eval_datasets: "gsm8k/test.jsonl" # Path to the evaluation dataset
prompt_key: "src" # Key for the prompt in the dataset
response_key: "tgt" # Key for the response in the dataset
dataloader_drop_last: true # Whether to drop the last incomplete batch in the DataLoader
balance_batch: true # Whether to balance batch size across dataset_world_size
use_remove_padding: true # Whether to remove padding tokens in the input

# distributed training args
tensor_parallel_degree: 2 # Degree of tensor parallelism
sequence_parallel: true # Whether to enable sequence parallelism
sharding_parallel_degree: -1 # Degree of sharding parallelism
sharding: "stage1" # Sharding strategy, e.g., "stage1" or "stage2"
sharding_parallel_config: "enable_release_grads" # Configuration for sharding parallelism
pipeline_parallel_degree: 1 # Degree of pipeline parallelism
virtual_pp_degree: 1 # Degree of virtual pipeline parallelism

# rollout args
max_prompt_len: 1024 # Maximum length of the prompt; longer prompts are automatically truncated
max_dec_len: 512 # Maximum length of the response
min_dec_len: 32 # Minimum length of the response
top_p: 1.0 # Top-p sampling parameter
temperature: 1.0 # Temperature parameter for sampling
repetition_penalty: 1.0 # Repetition penalty parameter
rollout_max_num_seqs: 1024 # The maximum number of sequences that can be processed in a single inference
rollout_quant_type: "" # Quantization type, e.g., "weight_only_int8"

# training args
do_train: true # Whether to perform training
seed: 42 # Random seed for reproducibility
global_batch_size: 256 # Global batch size for training (rollouts = rollout_n * global_batch_size)
global_gen_batch_size: -1 # Global generation batch size for dynamic sampling
global_mini_batch_size: 64 # Mini-batch size for training, default = (global_batch_size * rollout_n * update_iters) // dataset_world_size
rollout_n: 1 # Number of rollouts; set rollout_n = 1 for "ppo"
update_iters: 1 # Number of training iterations for rollout samples
per_device_logprob_batch_size: 4 # Log probability batch size per device
per_device_reward_batch_size: 2 # Reward batch size per device
per_device_value_batch_size: 2 # Value batch size per device
per_device_train_batch_size: 2 # Training micro batch size per device
# gradient_accumulation_steps: 4 # Gradient accumulation steps (auto-calculated): global_bz * rollout_n *
num_train_epochs: 5 # Number of training epochs
max_length: 2048 # Maximum length for training, should be larger than max_prompt_len + max_dec_len
adam_beta1: 0.9 # AdamW optimizer beta1
adam_beta2: 0.999 # AdamW optimizer beta2
adam_epsilon: 1e-8 # AdamW optimizer epsilon
max_grad_norm: 1.0 # Maximum gradient norm for clipping
max_steps: -1 # Maximum number of training steps
save_steps: 300 # Number of steps between model saves
save_strategy: "steps" # Strategy for saving models
ignore_save_lr_and_optim: true # Whether to ignore saving learning rate and optimizer state (leave empty if not specified)
disable_tqdm: true # Whether to disable the tqdm progress bar

# actor training args
learning_rate: 1e-6 # Learning rate for training
min_learning_rate: 1e-6 # Minimum learning rate
lr_scheduler_type: "constant" # Learning rate scheduler type
weight_decay: 1e-2 # Weight decay for the AdamW optimizer
warmup_ratio: 0.0 # Warmup ratio

# critic training args
critic_learning_rate: 1e-5 # Learning rate for critic model
critic_min_learning_rate: 1e-5 # Minimum learning rate for critic model
critic_lr_scheduler_type: "constant" # Learning rate scheduler type for critic model
critic_weight_decay: 1e-2 # Weight decay for the AdamW optimizer of critic model
critic_warmup_ratio: 0.0 # Warmup ratio for critic model

# RL args
kl_coeff: 0.0 # KL coefficient
kl_loss_coeff: 0.001 # KL loss coefficient
pg_loss_coeff: 1.0 # Policy gradient loss coefficient
entropy_coeff: 0.001 # Entropy coefficient
clip_range_ratio: 0.2 # The clipping range for the ratio between the old and new policy (PPO algorithm)
clip_range_ratio_low: 0.2 # The lower clipping range for the ratio between the old and new policy (PPO algorithm)
clip_range_ratio_high: 0.2 # The upper clipping range for the ratio between the old and new policy (PPO algorithm)
clip_range_score: 10.0 # The clipping range for the output of the score model. The reward is clipped into [-clip_range_score, clip_range_score].
enable_overlong_reward_buffer: false # Whether to enable the overlong reward buffer
overlong_reward_buffer: 256 # The length of the overlong reward buffer
overlong_penalty_factor: 1.0 # The penalty factor for the overlong reward buffer
clip_range_value: 0.5 # The clipping range for the output of the value model. The value is clipped into [-clip_range_value, clip_range_value].
normalize_reward: false # Whether to normalize reward
normalize_advantage: false # Whether to normalize advantage
dynamic_sampling: false # Whether to use dynamic sampling, which is introduced in the DAPO algorithm https://arxiv.org/abs/2503.14476
max_gen_batches: 2 # Maximum number of generation batches for dynamic sampling
use_fp32_compute: true # Whether to use fp32 to compute log probs, rewards, advantages and loss

# eval args
do_eval: true # Whether to perform evaluation
per_device_eval_batch_size: 1319 # Evaluation batch size per device
evaluation_strategy: "steps" # Evaluation strategy, e.g., "steps"
eval_steps: 10 # Number of steps between evaluations

# device memory optimization args
use_flash_attention: true # Whether to use fused attention operations
use_fused_rms_norm: true # Whether to use fused RMS norm operations, which requires installing fused_ln in slm/model_zoo/gpt-3/external_ops
use_fused_rope: false # Whether to use fused rope operations
use_fused_head_and_loss_fn: true # Whether to use fused head and loss function
use_fused_linear: true # Whether to use fused linear operations (appears to be an unused parameter)
recompute: false # Whether to enable gradient checkpointing for memory optimization
recompute_use_reentrant: false # Whether to use reentrant recompute
recompute_granularity: "full" # Granularity of recompute
bf16: true # Whether to use mixed precision with bfloat16
fp16_opt_level: "O2" # Optimization level for fp16 and bf16 training
amp_master_grad: false # Whether to use float32 weight gradients for master weights in amp opt level "O2"
amp_custom_black_list: ["reduce_sum", "softmax_with_cross_entropy", "c_softmax_with_cross_entropy", "elementwise_div", "sin", "cos"] # Custom black list for amp
amp_custom_white_list: ["lookup_table", "lookup_table_v2", "flash_attn", "matmul", "matmul_v2", "fused_gemm_epilogue"] # Custom white list for amp
offload_level: "freeze_model" # Level of model offloading to pinned memory, supported values: freeze_model, train_model, optimizer
release_grads: true # Whether to release gradients
offload_optim: false # Whether to offload the optimizer to pinned memory

# benchmark args
skip_profile_timer: false # Whether to skip profiling time
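
The length and batch-size comments above imply a couple of invariants that are easy to check by hand. A small sketch that loads this config with PyYAML and reproduces the global_mini_batch_size formula; the relative config path and a dataset_world_size of 4 (for example, 8 GPUs with tensor_parallel_degree 2) are assumptions for illustration:

import yaml  # PyYAML

# Load the example config shipped with this commit (path assumed relative to the repo root).
with open("llm/config/qwen/ppo_argument.yaml") as f:
    cfg = yaml.safe_load(f)

# Training length must cover the prompt plus the generated response.
assert cfg["max_length"] >= cfg["max_prompt_len"] + cfg["max_dec_len"]

# Default mini-batch formula from the comment on global_mini_batch_size.
dataset_world_size = 4  # assumption: 8 GPUs / tensor_parallel_degree 2
derived = (cfg["global_batch_size"] * cfg["rollout_n"] * cfg["update_iters"]) // dataset_world_size
print(derived)  # 256 * 1 * 1 // 4 = 64, matching global_mini_batch_size above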
