Commit 9f77043

add xpu case
1 parent c34d360 commit 9f77043

3 files changed: 127 additions, 0 deletions
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+{
+    "1": 11.43915844,
+    "2": 10.98821735,
+    "3": 10.11469746,
+    "4": 9.73008347,
+    "5": 8.18760586,
+    "6": 8.02382469,
+    "7": 7.94480753,
+    "8": 7.78190613,
+    "9": 7.66679621,
+    "10": 7.5971694
+}
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+### data
+train_dataset_type: messages
+train_dataset_path: tests/fixtures/dummy/sft-vl/thinking_safety_demo.jsonl
+train_dataset_prob: "1.0"
+max_seq_len: 32768
+packing: true
+mix_strategy: concat
+template_backend: custom
+template: ernie_vl
+random_shuffle: false
+dataloader_num_workers: 4
+
+### model
+model_name_or_path: baidu/ERNIE-4.5-VL-28B-A3B-Thinking
+attn_impl: flashmask
+num_hidden_layers: 4
+
+### finetuning
+# base
+stage: VL-SFT
+fine_tuning: full
+seed: 23
+do_train: true
+do_eval: false
+per_device_train_batch_size: 1
+num_train_epochs: 1
+max_steps: 10
+save_steps: 10000
+save_strategy: steps
+logging_steps: 1
+gradient_accumulation_steps: 1
+logging_dir: ./vdl_log
+output_dir: ./checkpoints/ernie-vl-thinking-sft-full
+disable_tqdm: true
+
+# train
+warmup_steps: 0
+learning_rate: 1.0e-5
+
+# performance
+tensor_model_parallel_size: 2
+pipeline_model_parallel_size: 2
+sharding: stage1
+use_sparse_head_and_loss_fn: true
+bf16: true
+fp16_opt_level: O2
+save_checkpoint_format: "flex_checkpoint"
+load_checkpoint_format: "flex_checkpoint"
+freeze_config: freeze_vision
+
+# recompute
+recompute: true
+recompute_granularity: full
+recompute_method: uniform
+recompute_num_layers: 1
+recompute_modules: ["loss_fn"]
+recompute_use_reentrant: true
+
+use_flash_attention: true
+sequence_parallel: true
+pp_seg_method: layer:Ernie4_5_DecoderLayer|ErnieDecoderLayer|EmptyLayer
+offload_queue: true
+pp_delay_scale_loss: true
+overlap_p2p_comm: true
+best_unbalanced_scheduler: true
+sharding_comm_buffer_size_MB: 2048
+save_sharding_stage1_model_include_freeze_params: true
+offload_optim: false
+tensorwise_offload_optimizer: false
+unified_checkpoint_config: ignore_merge_optimizer
+pre_alloc_memory: 60
+amp_master_grad: 1
+
+device: xpu
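
Review note: the config sets `tensor_model_parallel_size: 2` and `pipeline_model_parallel_size: 2`, so one model replica would span 2 x 2 = 4 devices under the usual model-parallel convention. The sketch below only illustrates that arithmetic. The function `model_parallel_layout` and the `world_size` argument are hypothetical names introduced here, not part of PaddleFormers, and how the framework actually derives its sharding or data-parallel degree is not shown in this commit.

```python
# Hypothetical helper: illustrates the device arithmetic implied by the
# tensor/pipeline parallel sizes. Not taken from PaddleFormers.
CONFIG = {
    "tensor_model_parallel_size": 2,    # value from the YAML above
    "pipeline_model_parallel_size": 2,  # value from the YAML above
}


def model_parallel_layout(cfg, world_size):
    """Return devices per model replica and the number of replicas.

    Assumes the common convention that world_size must be a multiple of
    tensor_parallel * pipeline_parallel.
    """
    tp = cfg["tensor_model_parallel_size"]
    pp = cfg["pipeline_model_parallel_size"]
    devices_per_replica = tp * pp
    if world_size % devices_per_replica != 0:
        raise ValueError(
            f"world_size={world_size} is not a multiple of tp*pp={devices_per_replica}"
        )
    return {
        "devices_per_replica": devices_per_replica,
        "data_parallel_replicas": world_size // devices_per_replica,
    }
```

Under that assumption, a CI machine would need at least 4 XPU devices to run this case.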
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from conftest import run_command_and_validate
+
+
+def test_ernie_28b_thinking_sft_training(project_root, base_value_dir, log_file):
+    """Test ERNIE-28B-thinking SFT training loss values.
+
+    This test runs the following shell command:
+        paddleformers-cli train scripts/xpu_ci/config/ernie_vl_28b_sft.yaml
+
+    Then validates that loss values match the baseline within tolerance of 1e-6.
+    """
+    # Define the exact shell command to execute
+    cmd = "paddleformers-cli train scripts/xpu_ci/config/ernie_vl_28b_sft.yaml"
+
+    # Execute command and validate results
+    passed, error_msg = run_command_and_validate(
+        cmd=cmd,
+        baseline_path=base_value_dir / "ernie_21b_sft_loss.json",
+        log_file=log_file,
+        working_dir=project_root,
+        tolerance=1e-6,
+        timeout=3600,
+    )
+
+    if not passed:
+        pytest.fail(error_msg)
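
Review note: `run_command_and_validate` lives in `conftest` and is not part of this diff, so its comparison logic is not visible here. As a rough sketch of what "match the baseline within tolerance of 1e-6" could mean, the hypothetical function below compares per-step losses using an absolute tolerance. The name `compare_losses`, the dict-of-step-strings input shape, and the choice of absolute rather than relative tolerance are all assumptions, not the helper's confirmed behavior.

```python
import math


def compare_losses(baseline, actual, tolerance=1e-6):
    """Hypothetical check: compare per-step losses against a baseline.

    Both arguments map step numbers (as strings, like the JSON baseline)
    to loss values. Returns (passed, error_msg), mirroring the tuple the
    test unpacks from run_command_and_validate.
    """
    mismatches = []
    for step, expected in baseline.items():
        if step not in actual:
            mismatches.append(f"step {step}: missing from training log")
            continue
        got = actual[step]
        # Absolute comparison only; rel_tol=0.0 disables the relative check.
        if not math.isclose(got, expected, rel_tol=0.0, abs_tol=tolerance):
            mismatches.append(f"step {step}: expected {expected}, got {got}")
    if mismatches:
        return False, "; ".join(mismatches)
    return True, ""
```

An absolute tolerance of 1e-6 on losses around 7 to 11 is strict, so a check of this kind would only pass when the run is effectively deterministic (fixed `seed`, `random_shuffle: false`, as in the config).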
