Skip to content

Commit 9c12335

Browse files
committed
Added tests for threshold processor
1 parent 843dd00 commit 9c12335

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# inference/tests/test_threshold_processor.py
2+
# docker-compose -f local.yml run --rm django pytest inference/tests/test_threshold_processor.py
3+
4+
from unittest.mock import patch
5+
6+
import pytest
7+
8+
from inference.utils.threshold_processor import ClassificationThresholdProcessor
9+
10+
11+
class TestClassificationThresholdProcessor:
12+
"""Test suite for ClassificationThresholdProcessor class."""
13+
14+
@pytest.fixture
15+
def test_thresholds(self):
16+
"""Test thresholds for generic processor."""
17+
return {"TAG_A": 0.7, "TAG_B": 0.5, "TAG_C": 0.8}
18+
19+
@pytest.fixture
20+
def processor(self, test_thresholds):
21+
"""Create a ClassificationThresholdProcessor with test thresholds."""
22+
return ClassificationThresholdProcessor(test_thresholds, default_threshold=0.6)
23+
24+
def test_initialization(self, test_thresholds):
25+
"""Test initialization with provided thresholds."""
26+
processor = ClassificationThresholdProcessor(test_thresholds, default_threshold=0.6)
27+
assert processor.thresholds == test_thresholds
28+
assert processor.default_threshold == 0.6
29+
30+
def test_for_tdamm_factory(self):
31+
"""Test the for_tdamm class factory method."""
32+
with patch("inference.utils.threshold_processor.TDAMM_TAG_THRESHOLDS", {"TAG1": 0.7}), patch(
33+
"inference.utils.threshold_processor.DEFAULT_TDAMM_THRESHOLD", 0.5
34+
):
35+
processor = ClassificationThresholdProcessor.for_tdamm()
36+
assert processor.thresholds == {"TAG1": 0.7}
37+
assert processor.default_threshold == 0.5
38+
39+
def test_for_division_factory(self):
40+
"""Test the for_division class factory method."""
41+
with patch("inference.utils.threshold_processor.DIVISION_TAG_THRESHOLDS", {1: 0.7}), patch(
42+
"inference.utils.threshold_processor.DEFAULT_DIVISION_THRESHOLD", 0.5
43+
):
44+
processor = ClassificationThresholdProcessor.for_division()
45+
assert processor.thresholds == {1: 0.7}
46+
assert processor.default_threshold == 0.5
47+
48+
def test_get_threshold_exact_match(self, processor):
49+
"""Test get_threshold with an exact tag match."""
50+
assert processor.get_threshold("TAG_A") == 0.7
51+
assert processor.get_threshold("TAG_B") == 0.5
52+
assert processor.get_threshold("TAG_C") == 0.8
53+
54+
def test_get_threshold_no_match(self, processor):
55+
"""Test get_threshold with a tag that doesn't exist."""
56+
assert processor.get_threshold("UNKNOWN_TAG") == 0.6 # default threshold
57+
58+
def test_filter_classifications_all_pass(self, processor):
59+
"""Test filter_classifications where all pass their thresholds."""
60+
classifications = {
61+
"TAG_A": 0.8, # 0.8 > 0.7 threshold
62+
"TAG_B": 0.6, # 0.6 > 0.5 threshold
63+
}
64+
filtered = processor.filter_classifications(classifications)
65+
assert len(filtered) == 2
66+
assert "TAG_A" in filtered
67+
assert "TAG_B" in filtered
68+
69+
def test_filter_classifications_some_pass(self, processor):
70+
"""Test filter_classifications where some pass their thresholds."""
71+
classifications = {
72+
"TAG_A": 0.6, # 0.6 < 0.7 threshold
73+
"TAG_B": 0.6, # 0.6 > 0.5 threshold
74+
"TAG_C": 0.9, # 0.9 > 0.8 threshold
75+
}
76+
filtered = processor.filter_classifications(classifications)
77+
assert len(filtered) == 2
78+
assert "TAG_A" not in filtered
79+
assert "TAG_B" in filtered
80+
assert "TAG_C" in filtered
81+
82+
def test_filter_classifications_none_pass(self, processor):
83+
"""Test filter_classifications where none pass their thresholds."""
84+
classifications = {
85+
"TAG_A": 0.6, # 0.6 < 0.7 threshold
86+
"TAG_C": 0.7, # 0.7 < 0.8 threshold
87+
}
88+
filtered = processor.filter_classifications(classifications)
89+
assert len(filtered) == 0
90+
91+
def test_filter_classifications_default_threshold(self, processor):
92+
"""Test filter_classifications using default threshold for unknown tags."""
93+
classifications = {
94+
"UNKNOWN_TAG": 0.7, # 0.7 > 0.6 default threshold
95+
}
96+
filtered = processor.filter_classifications(classifications)
97+
assert len(filtered) == 1
98+
assert "UNKNOWN_TAG" in filtered
99+
100+
def test_filter_classifications_string_confidence(self, processor):
101+
"""Test filter_classifications with string confidence values."""
102+
classifications = {
103+
"TAG_A": "0.8", # Should convert to float and pass
104+
"TAG_B": "0.4", # Should convert to float and fail
105+
}
106+
filtered = processor.filter_classifications(classifications)
107+
assert len(filtered) == 1
108+
assert "TAG_A" in filtered
109+
assert "TAG_B" not in filtered
110+
111+
def test_filter_classifications_invalid_confidence(self, processor):
112+
"""Test filter_classifications with invalid confidence values."""
113+
classifications = {
114+
"TAG_A": 0.8,
115+
"TAG_B": "not a number", # Invalid
116+
"TAG_C": 0.9,
117+
}
118+
filtered = processor.filter_classifications(classifications)
119+
assert len(filtered) == 2
120+
assert "TAG_A" in filtered
121+
assert "TAG_B" not in filtered
122+
assert "TAG_C" in filtered
123+
124+
def test_filter_classifications_exact_threshold(self, processor):
125+
"""Test filter_classifications with confidence exactly at threshold."""
126+
classifications = {
127+
"TAG_A": 0.7, # Exactly at threshold (0.7)
128+
"TAG_B": 0.5, # Exactly at threshold (0.5)
129+
}
130+
filtered = processor.filter_classifications(classifications)
131+
assert len(filtered) == 2
132+
assert "TAG_A" in filtered
133+
assert "TAG_B" in filtered
134+
135+
def test_filter_classifications_empty_dict(self, processor):
136+
"""Test filter_classifications with an empty dictionary."""
137+
filtered = processor.filter_classifications({})
138+
assert filtered == {}

0 commit comments

Comments
 (0)