
Commit d8765f2

add estimate training in dpo training (#2573)
1 parent c9e9b68 commit d8765f2

File tree: 3 files changed, +220 -1 lines changed


examples/alignment/dpo/dpo_argument.py

Lines changed: 4 additions & 0 deletions
@@ -29,6 +29,10 @@
 class DPOTrainingArguments(TrainingArguments):
     """DPOTrainingArguments"""

+    num_of_gpus: int = field(
+        default=-1,
+        metadata={"help": "Number of gpus used in dpo estimate training."},
+    )
     unified_checkpoint: bool = field(
         default=True,
         metadata={"help": "Enable fused linear grad add strategy."},
examples/alignment/dpo/dpo_estimate_training.py (new file; name inferred from the run_dpo.py import)

Lines changed: 181 additions & 0 deletions

@@ -0,0 +1,181 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Estimate DPO """

import json
import os

import numpy as np
import paddle

from paddleformers.utils.log import logger

# isort: off
# fmt: off
# isort: on
from paddleformers.datasets.dpo import create_dataset


def calculate_acc_steps(num_samples, train_batch, dataset_world_size, per_device_train_batch_size):
    """calculate_acc_steps

    Args:
        num_samples (int): Total training samples in dataset
        train_batch (int): Target global batch size
        dataset_world_size (int): Number of dataset parallel training devices
        per_device_train_batch_size (int): Batch size per GPU/device

    Returns:
        int: Number of gradient accumulation steps needed to achieve:
            - Global batch size target
            - Full dataset coverage
    """
    samples_per_batch = per_device_train_batch_size * dataset_world_size * num_samples / train_batch
    if num_samples < 100:
        recommend_bs = 8
    elif num_samples < 1000:
        recommend_bs = 16
    elif num_samples < 10000:
        recommend_bs = 32
    elif num_samples < 100000:
        recommend_bs = 64
    else:
        recommend_bs = 128
    return min(np.ceil(recommend_bs / samples_per_batch), 32)


def dpo_estimate_training(tokenizer, data_args, training_args, config, train_dataset=None):
    """dpo_estimate_training

    Args:
        tokenizer (PreTrainedTokenizer): Text tokenization
        data_args (DataArguments): Datasets configuration
        training_args (TrainingArguments): Training configuration
        config (PretrainedConfig): Model configuration
        train_dataset (Dataset, optional): Preloaded dataset

    Returns:
        training_args (TrainingArguments): Training configuration with max_steps setting
        res (Dict): Training estimate results
    """
    if training_args.should_save or training_args.should_save_model_state:
        os.makedirs(training_args.output_dir, exist_ok=True)
    if train_dataset is None:
        dataset_config = {
            "tokenizer": tokenizer,
            "max_seq_len": data_args.max_seq_len,
            "max_prompt_len": data_args.max_prompt_len,
            "random_seed": training_args.seed,
            "num_replicas": 1,
            "rank": 0,
            "num_samples_each_epoch": data_args.num_samples_each_epoch,
            "random_shuffle": data_args.random_shuffle,
            "greedy_intokens": data_args.greedy_intokens,
            "buffer_size": data_args.buffer_size,
            "mask_out_eos_token": data_args.mask_out_eos_token,
            "packing": data_args.packing,
            "mix_strategy": data_args.mix_strategy,
            "encode_one_turn": data_args.encode_one_turn,
        }
        train_dataset = create_dataset(
            task_group=data_args.train_dataset_path,
            task_group_prob=data_args.train_dataset_prob,
            sub_dataset_type=data_args.train_dataset_type,
            **dataset_config
        )
    max_samples = len(train_dataset.mix_datasets)
    if max_samples > 0:
        if training_args.num_of_gpus > 0:
            dataset_world_size = (
                training_args.num_of_gpus
                // max(1, training_args.tensor_parallel_degree)
                // max(1, training_args.pipeline_parallel_degree))
            if dataset_world_size < 1:
                raise ValueError("dataset_world_size must be positive, please verify your config")
        else:
            dataset_world_size = training_args.dataset_world_size

        num_samples = 0
        train_tokens = 0
        train_batch = 0
        for sequences in train_dataset:
            if num_samples >= max_samples:
                break
            train_batch += 1
            for sequence in sequences:
                train_tokens += len(sequence.input_ids)
                num_samples += 1
        if training_args.gradient_accumulation_steps < 0:
            training_args.gradient_accumulation_steps = calculate_acc_steps(
                num_samples, train_batch, dataset_world_size, training_args.per_device_train_batch_size)
        max_samples *= training_args.num_train_epochs
        train_tokens *= training_args.num_train_epochs
        train_batch *= training_args.num_train_epochs
        global_batch_size = (
            training_args.per_device_train_batch_size
            * training_args.gradient_accumulation_steps
            * dataset_world_size
        )
        if training_args.num_of_gpus < 0:
            training_args.num_of_gpus = paddle.distributed.get_world_size()

        training_args.max_steps = np.ceil(train_batch / global_batch_size)
        total_tokens = training_args.max_steps * data_args.max_seq_len * global_batch_size
        res = {
            "num_train_epochs": int(training_args.num_train_epochs),
            "max_steps": int(training_args.max_steps),
            "train_samples": int(max_samples),
            "gradient_accumulation_steps": int(training_args.gradient_accumulation_steps),
            "num_of_gpus": int(training_args.num_of_gpus),
            "per_device_train_batch_size": int(training_args.per_device_train_batch_size),
            "pipeline_parallel_degree": int(max(1, training_args.pipeline_parallel_degree)),
            "tensor_parallel_degree": int(max(1, training_args.tensor_parallel_degree)),
            "seed": int(training_args.seed),
            "num_samples_each_epoch": int(data_args.num_samples_each_epoch),
            "max_seq_len": int(data_args.max_seq_len),
            "max_prompt_len": int(data_args.max_prompt_len),
            "total_tokens": int(total_tokens),
            "train_tokens": int(train_tokens),
            "valid": True,
        }
        if train_batch / training_args.num_train_epochs / global_batch_size < 1:
            logger.warning("This dataset is too small, you'd better enlarge your dataset.")
            res["valid"] = False
    else:
        training_args.max_steps = 0
        logger.error("No valid data found, please check your dataset format.")
        res = {
            "num_train_epochs": int(training_args.num_train_epochs),
            "max_steps": int(training_args.max_steps),
            "train_samples": 0,
            "gradient_accumulation_steps": int(training_args.gradient_accumulation_steps),
            "num_of_gpus": int(training_args.num_of_gpus),
            "per_device_train_batch_size": int(training_args.per_device_train_batch_size),
            "pipeline_parallel_degree": int(max(1, training_args.pipeline_parallel_degree)),
            "tensor_parallel_degree": int(max(1, training_args.tensor_parallel_degree)),
            "seed": int(training_args.seed),
            "num_samples_each_epoch": 6000000,
            "max_seq_len": int(data_args.max_seq_len),
            "max_prompt_len": int(data_args.max_prompt_len),
            "valid": False,
        }

    logger.info(f"training argument: {res}")
    # NOTE(gongenlei): if not int, broadcast will overflow
    training_args.max_steps = int(training_args.max_steps)
    with open(os.path.join(training_args.output_dir, "dpo_train_args.json"), "w", encoding="utf-8") as f:
        json.dump(res, f)
    return training_args, res
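To make the estimator concrete, here is a worked example of the arithmetic above. All numbers are hypothetical (5,000 preference samples packed into 1,250 batches, 4 data-parallel ranks, per_device_train_batch_size of 1, one epoch); it is a sketch of the same formulas, not code from the commit:

import numpy as np

num_samples = 5000               # samples counted while iterating the dataset
train_batch = 1250               # packed batches produced for those samples
dataset_world_size = 4           # data-parallel ranks
per_device_train_batch_size = 1
num_train_epochs = 1

# calculate_acc_steps: 1000 <= 5000 < 10000, so recommend_bs is 32
samples_per_batch = per_device_train_batch_size * dataset_world_size * num_samples / train_batch  # 16.0
acc_steps = min(np.ceil(32 / samples_per_batch), 32)  # ceil(32 / 16) -> 2 accumulation steps

# dpo_estimate_training: derive the global batch size and max_steps
global_batch_size = per_device_train_batch_size * acc_steps * dataset_world_size  # 8
max_steps = int(np.ceil(train_batch * num_train_epochs / global_batch_size))      # ceil(156.25) -> 157
print(acc_steps, global_batch_size, max_steps)  # 2.0 8.0 157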

examples/alignment/dpo/run_dpo.py

Lines changed: 35 additions & 1 deletion
@@ -25,11 +25,17 @@
     DPOModelArgument,
     DPOTrainingArguments,
 )
+from dpo_estimate_training import dpo_estimate_training

 from paddleformers.datasets.dpo import collate_fn, create_dataset
 from paddleformers.nn.attention import AttentionInterface
 from paddleformers.peft import LoRAConfig, LoRAModel
-from paddleformers.trainer import PdArgumentParser, get_last_checkpoint, set_seed
+from paddleformers.trainer import (
+    IntervalStrategy,
+    PdArgumentParser,
+    get_last_checkpoint,
+    set_seed,
+)
 from paddleformers.transformers import (
     AutoConfig,
     AutoModelForCausalLM,

@@ -226,6 +232,34 @@ def main():
         "mix_strategy": data_args.mix_strategy,
         "encode_one_turn": data_args.encode_one_turn,
     }
+    if training_args.max_steps == -1:
+        if data_args.mix_strategy == "random":
+            raise ValueError(
+                "When using 'random' mix_strategy, max_steps must be explicitly set (cannot be -1). "
+                "Random mixing requires a fixed number of training steps to properly sample data."
+            )
+        if training_args.should_load_dataset and paddle.distributed.get_rank() == 0:
+            training_args, _ = dpo_estimate_training(tokenizer, data_args, training_args, config=model.config)
+
+        if paddle.distributed.get_world_size() > 1:
+            paddle.distributed.barrier()
+            pd_max_steps = paddle.to_tensor([training_args.max_steps])
+            paddle.distributed.broadcast(pd_max_steps, src=0)
+            training_args.max_steps = int(pd_max_steps.item())
+        logger.info(
+            f"Re-setting training_args.max_steps to {training_args.max_steps} ({training_args.num_train_epochs})"
+        )
+        if training_args.max_steps <= 0:
+            raise ValueError(f"Invalid max_steps: {training_args.max_steps}. Please check your dataset")
+        if training_args.save_strategy == IntervalStrategy.EPOCH:
+            training_args.save_strategy = IntervalStrategy.STEPS
+            training_args.save_steps = int(training_args.max_steps / training_args.num_train_epochs)
+        if training_args.evaluation_strategy == IntervalStrategy.EPOCH:
+            training_args.evaluation_strategy = IntervalStrategy.STEPS
+            training_args.eval_steps = int(training_args.max_steps / training_args.num_train_epochs)
+        if training_args.logging_strategy == IntervalStrategy.EPOCH:
+            training_args.logging_strategy = IntervalStrategy.STEPS
+            training_args.logging_steps = int(training_args.max_steps / training_args.num_train_epochs)
     if training_args.do_train and training_args.should_load_dataset:
         train_dataset = create_dataset(
             task_group=data_args.train_dataset_path,
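In run_dpo.py the estimate only runs when max_steps is left at -1 (and mix_strategy is not "random"): rank 0 walks the dataset via dpo_estimate_training, the resulting max_steps is broadcast to all ranks, and any epoch-based save/eval/logging strategy is converted to a step interval of max_steps / num_train_epochs. Because the estimator also dumps its result to dpo_train_args.json under output_dir, the estimate can be inspected offline; a minimal sketch, assuming a hypothetical output directory named "checkpoints":

import json
import os

output_dir = "checkpoints"  # hypothetical; use the output_dir from your training config
with open(os.path.join(output_dir, "dpo_train_args.json"), "r", encoding="utf-8") as f:
    estimate = json.load(f)

# Keys written by dpo_estimate_training, e.g. max_steps, gradient_accumulation_steps, train_tokens, valid
print(estimate["max_steps"], estimate["gradient_accumulation_steps"], estimate["valid"])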
