Skip to content

Commit 8860942

Browse files
authored
Add PaddleOCR-VL CI case (#1386)
1 parent 88469f5 commit 8860942

File tree

1 file changed

+171
-0
lines changed

1 file changed

+171
-0
lines changed

tests/gpu/test_ocr.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import re
17+
import shutil
18+
import subprocess
19+
import tempfile
20+
import urllib.request
21+
22+
23+
import allure
24+
import yaml
25+
26+
# Filesystem locations and port used by this CI case.
OUTPUT_DIR = "./output/"
LOG_DIR = "./erniekit_dist_log/"
MODEL_PATH = "./PaddleOCR-VL/"
CONFIG_PATH = "./examples/configs/PaddleOCR-VL/"
SFT_CONFIG_PATH = f"{CONFIG_PATH}sft/"
PORT = 8188

# Exported at import time so that every subprocess launched afterwards
# inherits them.
# NOTE(review): judging by their names these flags pin deterministic
# behaviour, presumably so the loss baseline is reproducible -- confirm.
os.environ.update(
    {
        "NVIDIA_TF32_OVERRIDE": "0",
        "NCCL_ALGO": "Tree",
        "FLAGS_embedding_deterministic": "1",
        "FLAGS_cudnn_deterministic": "1",
    }
)
37+
38+
39+
def clean_output_dir():
    """Delete the output directory and the log directory when they exist."""
    for stale_dir in (OUTPUT_DIR, LOG_DIR):
        if os.path.exists(stale_dir):
            shutil.rmtree(stale_dir)
44+
45+
46+
def prepare_data(data_dir="examples/data"):
    """Download the Bengali OCR-VL SFT train/test JSONL files when missing.

    Args:
        data_dir: Directory that receives the files; created if absent.
            Defaults to ``examples/data``.

    A file that already exists is left untouched. A failed download is
    reported on stdout and does not raise (deliberate best-effort).
    """
    os.makedirs(data_dir, exist_ok=True)

    files = {
        "ocr_vl_sft-train_Bengali.jsonl": "https://paddleformers.bj.bcebos.com/datasets/ocr_vl_sft-train_Bengali.jsonl",
        "ocr_vl_sft-test_Bengali.jsonl": "https://paddleformers.bj.bcebos.com/datasets/ocr_vl_sft-test_Bengali.jsonl",
    }

    for filename, url in files.items():
        file_path = os.path.join(data_dir, filename)

        if os.path.exists(file_path):
            print(f"{filename} already exists, skip downloading.")
            continue

        print(f"Downloading {filename} ...")
        # Download to a sibling ".part" file and rename only on success, so
        # an interrupted transfer never leaves a truncated file that a later
        # run would mistake for a complete dataset and skip.
        partial_path = file_path + ".part"
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, file_path)
        except Exception as e:
            # Best-effort by design: report the failure and keep going.
            print(f"Failed to download {filename}: {e}")
            try:
                os.remove(partial_path)
            except OSError:
                # Nothing was written (or it is already gone).
                pass
        else:
            print(f"Saved to {file_path}")
67+
68+
69+
def default_args(yaml_path):
    """Parse the YAML file at *yaml_path* with ``yaml.safe_load`` and return the result."""
    with open(yaml_path, mode="r", encoding="utf-8") as config_file:
        parsed = yaml.safe_load(config_file)
    return parsed
72+
73+
74+
def run_update_config_training(config, steps="train"):
    """Dump *config* to a temporary YAML file and run ``erniekit <steps> <file>``.

    Args:
        config: Object serialised with ``yaml.dump`` into the temporary
            config file.
        steps: The erniekit sub-command. ``"export"`` additionally appends
            ``lora=True`` to the command line. ``"server"`` and ``"chat"``
            start the process without waiting for it.

    Returns:
        For ``"server"`` and ``"chat"``: the running ``subprocess.Popen``
        object (stderr merged into stdout; ``"chat"`` also has a stdin pipe
        and line buffering). For any other step: a tuple
        ``(returncode, output)`` where ``output`` is the combined
        stdout/stderr text.
    """
    # delete=False keeps the file on disk after the handle is closed, so the
    # child process can open it by name; closing first flushes the YAML.
    with tempfile.NamedTemporaryFile(
        mode="w+", suffix=".yaml", delete=False
    ) as temp_config:
        yaml.dump(config, temp_config)
        temp_config_path = temp_config.name

    cmd = ["erniekit", steps, temp_config_path]
    if steps == "export":
        cmd.append("lora=True")

    if steps == "server":
        # start_new_session=True performs setsid() in the child; it is the
        # documented replacement for preexec_fn=os.setsid.
        return subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            start_new_session=True,
        )

    if steps == "chat":
        return subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            stdin=subprocess.PIPE,
            text=True,
            bufsize=1,
        )

    try:
        result = subprocess.run(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
        )
    finally:
        # The blocking command has finished with the config, so remove it
        # instead of leaking one temp file per invocation. The background
        # branches above return early because the child may still read it.
        try:
            os.remove(temp_config_path)
        except OSError:
            pass
    return result.returncode, result.stdout
112+
113+
114+
def assert_result(ret_code, log_output):
    """Raise ``AssertionError`` for a non-zero *ret_code*, printing the last 30 log lines first."""
    if ret_code == 0:
        return

    log_tail = log_output.strip().splitlines()[-30:]
    print("\n".join(log_tail))
    raise AssertionError("Training Failed")
119+
120+
121+
def assert_loss(base_loss, log_path=None, tolerance=0.0001):
    """Compare the average logged loss with the expected value.

    Every ``- loss: <float>`` occurrence in the log is collected, averaged,
    and rounded to six decimals before the comparison.

    Args:
        base_loss: Expected average loss.
        log_path: Log file to scan. Defaults to
            ``<cwd>/erniekit_dist_log/workerlog.0``.
        tolerance: Largest accepted absolute difference. Defaults to 0.0001.

    Raises:
        AssertionError: If the log holds no loss value, or the average
            differs from *base_loss* by more than *tolerance*.
    """
    if log_path is None:
        log_path = os.path.join(os.getcwd(), "erniekit_dist_log", "workerlog.0")

    loss_pattern = re.compile(r"- loss:\s*([0-9]+\.[0-9]+)")
    with open(log_path, encoding="utf-8") as f:
        losses = [float(m.group(1)) for m in loss_pattern.finditer(f.read())]

    # Report a missing loss explicitly rather than treating it as an average
    # of 0, which is misleading and would pass for a baseline near 0.
    if not losses:
        raise AssertionError(f"no loss values found in {log_path}")

    avg_loss = round(sum(losses) / len(losses), 6)

    # Raised explicitly so the check is not stripped under `python -O`.
    # Written as `not (... <= ...)` so a NaN on either side still fails.
    if not abs(avg_loss - base_loss) <= tolerance:
        raise AssertionError(
            f"loss: {avg_loss}, base_loss: {base_loss}, exist diff!"
        )
140+
141+
142+
def attach_log_file():
    """Attach ``erniekit_dist_log/workerlog.0`` to the Allure report.

    When the file does not exist, a short text attachment saying so is
    added in its place.
    """
    worker_log = os.path.join(os.getcwd(), "erniekit_dist_log", "workerlog.0")

    if not os.path.exists(worker_log):
        allure.attach(
            f"Log file was not generated: {worker_log}",
            name="Log Missing",
            attachment_type=allure.attachment_type.TEXT,
        )
        return

    allure.attach.file(
        worker_log, name="Trainning Log", attachment_type=allure.attachment_type.TEXT
    )
154+
155+
156+
def test_sft():
    """Run a 3-step SFT job from the 16k config and check exit code and average loss."""
    clean_output_dir()
    prepare_data()

    sft_yaml = os.path.join(SFT_CONFIG_PATH, "run_ocr_vl_sft_16k.yaml")
    config = default_args(sft_yaml).copy()
    # Shorten the run and point it at the local model and output directories.
    config.update(
        {
            "max_steps": 3,
            "save_steps": 2,
            "model_name_or_path": MODEL_PATH,
            "output_dir": OUTPUT_DIR,
        }
    )

    ret_code, train_log = run_update_config_training(config)
    # Attach the log before asserting so it is in the report even on failure.
    attach_log_file()
    assert_result(ret_code, train_log)

    expected_loss = 5.402314
    assert_loss(expected_loss)

0 commit comments

Comments
 (0)