Skip to content

Commit 49a388d

Browse files
authored
add test for cls_postprocess (#12534)
1 parent a764c56 commit 49a388d

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

.github/workflows/tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,4 @@ jobs:
2727
pip install -e .
2828
- name: Test with pytest
2929
run: |
30-
pytest tests/
30+
pytest --verbose tests/

tests/test_cls_postprocess.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import os
2+
import sys
3+
import pytest
4+
import paddle
5+
import numpy as np
6+
7+
current_dir = os.path.dirname(os.path.abspath(__file__))
8+
sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))
9+
10+
from ppocr.postprocess.cls_postprocess import ClsPostProcess
11+
12+
13+
# Fixtures for common test inputs
14+
@pytest.fixture
15+
def preds_tensor():
16+
return paddle.to_tensor(np.array([[0.1, 0.7, 0.2], [0.3, 0.3, 0.4]]))
17+
18+
19+
@pytest.fixture
20+
def label_list():
21+
return {0: "class0", 1: "class1", 2: "class2"}
22+
23+
24+
# Parameterize tests to cover multiple scenarios
25+
@pytest.mark.parametrize(
26+
"label_list, expected",
27+
[
28+
({0: "class0", 1: "class1", 2: "class2"}, [("class1", 0.7), ("class2", 0.4)]),
29+
(None, [(1, 0.7), (2, 0.4)]),
30+
],
31+
)
32+
def test_cls_post_process_with_and_without_label_list(
33+
preds_tensor, label_list, expected
34+
):
35+
post_process = ClsPostProcess(label_list=label_list)
36+
result = post_process(preds_tensor)
37+
assert isinstance(result, list), "Result should be a list"
38+
assert result == expected, f"Expected {expected}, got {result}"
39+
40+
41+
# Test with a key in the prediction dictionary
42+
def test_cls_post_process_with_key(preds_tensor, label_list):
43+
preds_dict = {"key": preds_tensor}
44+
post_process = ClsPostProcess(label_list=label_list, key="key")
45+
result = post_process(preds_dict)
46+
expected = [("class1", 0.7), ("class2", 0.4)]
47+
assert isinstance(result, list), "Result should be a list"
48+
assert result == expected, f"Expected {expected}, got {result}"
49+
50+
51+
# Test with label input
52+
def test_cls_post_process_with_label(preds_tensor, label_list):
53+
labels = [2, 0]
54+
post_process = ClsPostProcess(label_list=label_list)
55+
result, label_result = post_process(preds_tensor, labels)
56+
expected_result = [("class1", 0.7), ("class2", 0.4)]
57+
expected_label_result = [("class2", 1.0), ("class0", 1.0)]
58+
assert isinstance(result, list), "Result should be a list"
59+
assert result == expected_result, f"Expected {expected_result}, got {result}"
60+
assert isinstance(label_result, list), "Label result should be a list"
61+
assert (
62+
label_result == expected_label_result
63+
), f"Expected {expected_label_result}, got {label_result}"

0 commit comments

Comments
 (0)