Commit 0031d69

add estimate max_steps (#2566)
1 parent d73a424 commit 0031d69

11 files changed: +213 -48 lines

examples/config/ernie4_5/sft_argument_ernie4_5_0p3b.json

Lines changed: 4 additions & 6 deletions
@@ -6,23 +6,21 @@
   "eval_dataset_path": "./data/dev.json",
   "eval_dataset_prob": "1.0",
   "eval_dataset_type": "erniekit",
-  "packing": true,
-  "mix_strategy": "random",
+  "packing": false,
+  "mix_strategy": "concat",
   "output_dir": "./checkpoints/ernie4_5_paddle_sft_ckpts",
   "max_seq_len": 8192,
   "per_device_train_batch_size": 1,
   "gradient_accumulation_steps": 4,
-  "per_device_eval_batch_size": 8,
+  "per_device_eval_batch_size": 1,
   "eval_accumulation_steps":16,
   "num_train_epochs": 1,
   "learning_rate": 3e-05,
   "warmup_steps": 10,
   "logging_steps": 1,
-  "max_steps": 100,
+  "max_steps": -1,
   "evaluation_strategy": "epoch",
   "save_strategy": "epoch",
-  "src_length": 1024,
-  "max_length": 2048,
   "bf16": true,
   "fp16_opt_level": "O2",
   "do_train": true,

examples/config/ernie4_5_moe/sft_argument_ernie4_5_21b_a3b.json

Lines changed: 4 additions & 6 deletions
@@ -6,23 +6,21 @@
   "eval_dataset_path": "./data/dev.json",
   "eval_dataset_prob": "1.0",
   "eval_dataset_type": "erniekit",
-  "packing": true,
-  "mix_strategy": "random",
+  "packing": false,
+  "mix_strategy": "concat",
   "output_dir": "./checkpoints/ernie4_5_paddle_sft_ckpts",
   "max_seq_len": 8192,
   "per_device_train_batch_size": 1,
   "gradient_accumulation_steps": 4,
-  "per_device_eval_batch_size": 8,
+  "per_device_eval_batch_size": 1,
   "eval_accumulation_steps":16,
   "num_train_epochs": 1,
   "learning_rate": 3e-05,
   "warmup_steps": 10,
   "logging_steps": 1,
-  "max_steps": 100,
+  "max_steps": -1,
   "evaluation_strategy": "epoch",
   "save_strategy": "epoch",
-  "src_length": 1024,
-  "max_length": 2048,
   "bf16": true,
   "fp16_opt_level": "O2",
   "do_train": true,

examples/config/gpt_oss/sft_argument_gptoss_20b.json

Lines changed: 12 additions & 5 deletions
@@ -1,24 +1,31 @@
 {
   "model_name_or_path": "../gpt-oss-model-bf16",
-  "dataset_name_or_path": "./data",
+  "train_dataset_path": "./data/train.json",
+  "train_dataset_prob": "1.0",
+  "train_dataset_type": "erniekit",
+  "eval_dataset_path": "./data/dev.json",
+  "eval_dataset_prob": "1.0",
+  "eval_dataset_type": "erniekit",
+  "packing": false,
+  "mix_strategy": "concat",
   "output_dir": "./checkpoints/gptoss_paddle_sft_ckpts",
+  "max_seq_len": 8192,
   "overwrite_output_dir": false,
   "per_device_train_batch_size": 1,
   "gradient_accumulation_steps": 4,
-  "per_device_eval_batch_size": 8,
+  "per_device_eval_batch_size": 1,
   "eval_accumulation_steps":16,
   "num_train_epochs": 1,
   "learning_rate": 3e-05,
   "warmup_steps": 10,
   "logging_steps": 1,
+  "max_steps": -1,
   "evaluation_strategy": "epoch",
   "save_strategy": "epoch",
-  "src_length": 1024,
-  "max_length": 2048,
   "bf16": true,
   "fp16_opt_level": "O2",
   "do_train": true,
-  "do_eval": false,
+  "do_eval": true,
   "disable_tqdm": true,
   "load_best_model_at_end": true,
   "eval_with_do_generation": false,

examples/config/qwen/dpo_argument_qwen2_0p5b.json

Lines changed: 4 additions & 4 deletions
@@ -7,22 +7,22 @@
   "eval_dataset_path": "./data/dpo/dev.jsonl",
   "eval_dataset_prob": "1.0",
   "eval_dataset_type": "erniekit",
-  "packing": true,
-  "mix_strategy": "random",
+  "packing": false,
+  "mix_strategy": "concat",
   "output_dir": "./checkpoints/qwen2_paddle_dpo_ckpts",
+  "max_seq_len": 8192,
   "per_device_train_batch_size": 1,
   "gradient_accumulation_steps": 8,
   "per_device_eval_batch_size": 1,
   "num_train_epochs": 1,
-  "max_steps": 100,
+  "max_steps": -1,
   "learning_rate": 1e-06,
   "warmup_steps": 10,
   "logging_steps": 1,
   "evaluation_strategy": "steps",
   "save_strategy": "steps",
   "eval_steps": 100,
   "save_steps": 500,
-  "max_seq_len": 2048,
   "max_prompt_len": 1024,
   "bf16": true,
   "fp16_opt_level": "O2",

examples/config/qwen/dpo_lora_argument_qwen2_0p5b.json

Lines changed: 4 additions & 4 deletions
@@ -7,22 +7,22 @@
   "eval_dataset_path": "./data/dpo/dev.jsonl",
   "eval_dataset_prob": "1.0",
   "eval_dataset_type": "erniekit",
-  "packing": true,
-  "mix_strategy": "random",
+  "packing": false,
+  "mix_strategy": "concat",
   "output_dir": "./checkpoints/qwen2_paddle_dpo_lora_ckpts",
+  "max_seq_len": 8192,
   "per_device_train_batch_size": 1,
   "gradient_accumulation_steps": 8,
   "per_device_eval_batch_size": 1,
   "num_train_epochs": 1,
-  "max_steps": 100,
+  "max_steps": -1,
   "learning_rate": 1e-05,
   "warmup_steps": 10,
   "logging_steps": 1,
   "evaluation_strategy": "steps",
   "save_strategy": "steps",
   "eval_steps": 100,
   "save_steps": 500,
-  "max_seq_len": 4096,
   "max_prompt_len": 2048,
   "bf16": true,
   "fp16_opt_level": "O2",

examples/config/qwen/lora_argument_qwen2_0p5b.json

Lines changed: 4 additions & 6 deletions
@@ -6,23 +6,21 @@
   "eval_dataset_path": "./data/sft/dev.json",
   "eval_dataset_prob": "1.0",
   "eval_dataset_type": "erniekit",
-  "packing": true,
-  "mix_strategy": "random",
+  "packing": false,
+  "mix_strategy": "concat",
   "output_dir": "./checkpoints/qwen2_paddle_lora_ckpts",
   "max_seq_len": 8192,
   "per_device_train_batch_size": 1,
   "gradient_accumulation_steps": 4,
-  "per_device_eval_batch_size": 8,
+  "per_device_eval_batch_size": 1,
   "eval_accumulation_steps":16,
   "num_train_epochs": 1,
   "learning_rate": 3e-04,
   "warmup_steps": 30,
   "logging_steps": 1,
-  "max_steps": 100,
+  "max_steps": -1,
   "evaluation_strategy": "epoch",
   "save_strategy": "epoch",
-  "src_length": 1024,
-  "max_length": 2048,
   "bf16": true,
   "fp16_opt_level": "O2",
   "do_train": true,

examples/config/qwen/sft_argument_qwen2_0p5b.json

Lines changed: 4 additions & 6 deletions
@@ -6,23 +6,21 @@
   "eval_dataset_path": "./data/sft/dev.json",
   "eval_dataset_prob": "1.0",
   "eval_dataset_type": "erniekit",
-  "packing": true,
-  "mix_strategy": "random",
+  "packing": false,
+  "mix_strategy": "concat",
   "output_dir": "./checkpoints/qwen2_paddle_sft_ckpts",
   "max_seq_len": 8192,
   "per_device_train_batch_size": 1,
   "gradient_accumulation_steps": 4,
-  "per_device_eval_batch_size": 8,
+  "per_device_eval_batch_size": 1,
   "eval_accumulation_steps":16,
   "num_train_epochs": 1,
   "learning_rate": 3e-05,
   "warmup_steps": 10,
   "logging_steps": 1,
-  "max_steps": 100,
+  "max_steps": -1,
   "evaluation_strategy": "epoch",
   "save_strategy": "epoch",
-  "src_length": 1024,
-  "max_length": 2048,
   "bf16": true,
   "fp16_opt_level": "O2",
   "do_train": true,

examples/run_finetune.py

Lines changed: 50 additions & 2 deletions
@@ -12,16 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import gc
 import os
 import sys
 from functools import partial

 import paddle

+from paddleformers.datasets.data_utils import estimate_training
 from paddleformers.datasets.finetuning import collate_fn
 from paddleformers.datasets.finetuning import create_dataset as create_dataset_sft
 from paddleformers.peft import LoRAConfig, LoRAModel
-from paddleformers.trainer import PdArgumentParser, get_last_checkpoint, set_seed
+from paddleformers.trainer import (
+    IntervalStrategy,
+    PdArgumentParser,
+    get_last_checkpoint,
+    set_seed,
+)
 from paddleformers.transformers import (
     AutoConfig,
     AutoModelForCausalLM,
@@ -155,7 +162,7 @@ def main():
     if model_args.fuse_attention_ffn is not None:
         model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
     model_config.pp_seg_method = training_args.pp_seg_method
-    model_config.seq_length = data_args.max_length
+    model_config.seq_length = training_args.max_seq_len
     model_config.max_sequence_length = training_args.max_seq_len
     model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers
     logger.info(f"Final model config: {model_config}")
@@ -262,6 +269,47 @@ def neft_post_hook(module, input, output):
         model_args=model_args,
         max_seq_len=training_args.max_seq_len + model_config.num_nextn_predict_layers,
     )
+
+    if training_args.max_steps == -1:
+        if data_args.mix_strategy == "random":
+            raise ValueError(
+                "When using 'random' mix_strategy, max_steps must be explicitly set (cannot be -1). "
+                "Random mixing requires a fixed number of training steps to properly sample data."
+            )
+        if paddle.distributed.get_rank() == 0:
+            training_args.max_steps = estimate_training(train_dataset, data_args, training_args, model_args)
+            del train_dataset
+            gc.collect()
+            train_dataset = create_dataset_sft(
+                task_group=data_args.train_dataset_path,
+                task_group_prob=data_args.train_dataset_prob,
+                sub_dataset_type=data_args.train_dataset_type,
+                **dataset_config,
+            )
+
+        if paddle.distributed.get_world_size() > 1:
+            paddle.distributed.barrier()
+            max_steps = paddle.to_tensor([training_args.max_steps])
+            paddle.distributed.broadcast(max_steps, src=0)
+            training_args.max_steps = int(max_steps.item())
+        if training_args.max_steps <= 0:
+            raise ValueError(f"Invalid max_steps: {training_args.max_steps}. Please check your dataset")
+
+        logger.info(f"Re-setting training_args.max_steps to {training_args.max_steps}.")
+        # Create the learning_rate scheduler and optimizer
+        if training_args.decay_steps is None:
+            training_args.decay_steps = training_args.max_steps
+
+        if training_args.save_strategy == IntervalStrategy.EPOCH:
+            training_args.save_strategy = IntervalStrategy.STEPS
+            training_args.save_steps = int(training_args.max_steps / training_args.num_train_epochs)
+        if training_args.evaluation_strategy == IntervalStrategy.EPOCH:
+            training_args.evaluation_strategy = IntervalStrategy.STEPS
+            training_args.eval_steps = int(training_args.max_steps / training_args.num_train_epochs)
+        if training_args.logging_strategy == IntervalStrategy.EPOCH:
+            training_args.logging_strategy = IntervalStrategy.STEPS
+            training_args.logging_steps = int(training_args.max_steps / training_args.num_train_epochs)
+
     trainer = SFTTrainer(
         model=model,
         args=training_args,
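
Two things in this hunk are worth spelling out. First, because estimate_training() consumes the train dataset on rank 0, that rank rebuilds it before training, and every other rank then receives the estimated step count through a broadcast. A stripped-down sketch of that rank-0-computes, everyone-receives pattern (the helper name and the hard-coded 313 are stand-ins, not the repo's API):

```python
import paddle
import paddle.distributed as dist

def sync_scalar_from_rank0(value: int) -> int:
    """Broadcast a rank-0 scalar to all ranks, mirroring the hunk above."""
    if dist.get_world_size() > 1:
        dist.barrier()                     # wait for rank 0 to finish estimating
        t = paddle.to_tensor([value])      # wrap the scalar so it can be broadcast
        dist.broadcast(t, src=0)           # overwrite every rank's copy with rank 0's
        value = int(t.item())
    return value

max_steps = -1
if dist.get_rank() == 0:
    max_steps = 313                        # stand-in for estimate_training(...)
max_steps = sync_scalar_from_rank0(max_steps)
assert max_steps > 0, "estimation must produce a positive step count"
```

Second, once max_steps is known, any save/eval/logging strategy declared as "epoch" is rewritten in step terms as max_steps / num_train_epochs. That is why the example configs can keep "evaluation_strategy": "epoch" even though the trainer ultimately runs on a fixed step budget.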
