|
| 1 | +import os |
| 2 | +import sys |
| 3 | +import pytest |
| 4 | +import paddle |
| 5 | +import numpy as np |
| 6 | + |
| 7 | +current_dir = os.path.dirname(os.path.abspath(__file__)) |
| 8 | +sys.path.append(os.path.abspath(os.path.join(current_dir, ".."))) |
| 9 | + |
| 10 | +from ppocr.postprocess.cls_postprocess import ClsPostProcess |
| 11 | + |
| 12 | + |
| 13 | +# Fixtures for common test inputs |
| 14 | +@pytest.fixture |
| 15 | +def preds_tensor(): |
| 16 | + return paddle.to_tensor(np.array([[0.1, 0.7, 0.2], [0.3, 0.3, 0.4]])) |
| 17 | + |
| 18 | + |
| 19 | +@pytest.fixture |
| 20 | +def label_list(): |
| 21 | + return {0: "class0", 1: "class1", 2: "class2"} |
| 22 | + |
| 23 | + |
| 24 | +# Parameterize tests to cover multiple scenarios |
| 25 | +@pytest.mark.parametrize( |
| 26 | + "label_list, expected", |
| 27 | + [ |
| 28 | + ({0: "class0", 1: "class1", 2: "class2"}, [("class1", 0.7), ("class2", 0.4)]), |
| 29 | + (None, [(1, 0.7), (2, 0.4)]), |
| 30 | + ], |
| 31 | +) |
| 32 | +def test_cls_post_process_with_and_without_label_list( |
| 33 | + preds_tensor, label_list, expected |
| 34 | +): |
| 35 | + post_process = ClsPostProcess(label_list=label_list) |
| 36 | + result = post_process(preds_tensor) |
| 37 | + assert isinstance(result, list), "Result should be a list" |
| 38 | + assert result == expected, f"Expected {expected}, got {result}" |
| 39 | + |
| 40 | + |
| 41 | +# Test with a key in the prediction dictionary |
| 42 | +def test_cls_post_process_with_key(preds_tensor, label_list): |
| 43 | + preds_dict = {"key": preds_tensor} |
| 44 | + post_process = ClsPostProcess(label_list=label_list, key="key") |
| 45 | + result = post_process(preds_dict) |
| 46 | + expected = [("class1", 0.7), ("class2", 0.4)] |
| 47 | + assert isinstance(result, list), "Result should be a list" |
| 48 | + assert result == expected, f"Expected {expected}, got {result}" |
| 49 | + |
| 50 | + |
| 51 | +# Test with label input |
| 52 | +def test_cls_post_process_with_label(preds_tensor, label_list): |
| 53 | + labels = [2, 0] |
| 54 | + post_process = ClsPostProcess(label_list=label_list) |
| 55 | + result, label_result = post_process(preds_tensor, labels) |
| 56 | + expected_result = [("class1", 0.7), ("class2", 0.4)] |
| 57 | + expected_label_result = [("class2", 1.0), ("class0", 1.0)] |
| 58 | + assert isinstance(result, list), "Result should be a list" |
| 59 | + assert result == expected_result, f"Expected {expected_result}, got {result}" |
| 60 | + assert isinstance(label_result, list), "Label result should be a list" |
| 61 | + assert ( |
| 62 | + label_result == expected_label_result |
| 63 | + ), f"Expected {expected_label_result}, got {label_result}" |
0 commit comments