|
| 1 | +# inference/tests/test_threshold_processor.py |
| 2 | +# docker-compose -f local.yml run --rm django pytest inference/tests/test_threshold_processor.py |
| 3 | + |
| 4 | +from unittest.mock import patch |
| 5 | + |
| 6 | +import pytest |
| 7 | + |
| 8 | +from inference.utils.threshold_processor import ClassificationThresholdProcessor |
| 9 | + |
| 10 | + |
class TestClassificationThresholdProcessor:
    """Unit tests for ClassificationThresholdProcessor.

    Covers direct construction, the ``for_tdamm`` / ``for_division`` factory
    methods, per-tag threshold lookup, and ``filter_classifications``.
    """

    @pytest.fixture
    def test_thresholds(self):
        """Per-tag thresholds shared by the tests in this class."""
        return {
            "TAG_A": 0.7,
            "TAG_B": 0.5,
            "TAG_C": 0.8,
        }

    @pytest.fixture
    def processor(self, test_thresholds):
        """Processor built from the test thresholds with a 0.6 default."""
        return ClassificationThresholdProcessor(test_thresholds, default_threshold=0.6)

    def test_initialization(self, test_thresholds):
        """Constructor arguments are exposed as attributes."""
        instance = ClassificationThresholdProcessor(test_thresholds, default_threshold=0.6)
        assert instance.thresholds == test_thresholds
        assert instance.default_threshold == 0.6

    def test_for_tdamm_factory(self):
        """for_tdamm() builds a processor from the (patched) TDAMM constants."""
        with patch("inference.utils.threshold_processor.TDAMM_TAG_THRESHOLDS", {"TAG1": 0.7}):
            with patch("inference.utils.threshold_processor.DEFAULT_TDAMM_THRESHOLD", 0.5):
                built = ClassificationThresholdProcessor.for_tdamm()
                assert built.thresholds == {"TAG1": 0.7}
                assert built.default_threshold == 0.5

    def test_for_division_factory(self):
        """for_division() builds a processor from the (patched) division constants."""
        with patch("inference.utils.threshold_processor.DIVISION_TAG_THRESHOLDS", {1: 0.7}):
            with patch("inference.utils.threshold_processor.DEFAULT_DIVISION_THRESHOLD", 0.5):
                built = ClassificationThresholdProcessor.for_division()
                assert built.thresholds == {1: 0.7}
                assert built.default_threshold == 0.5

    def test_get_threshold_exact_match(self, processor):
        """get_threshold returns the configured value for each known tag."""
        expected_by_tag = (("TAG_A", 0.7), ("TAG_B", 0.5), ("TAG_C", 0.8))
        for tag, expected in expected_by_tag:
            assert processor.get_threshold(tag) == expected

    def test_get_threshold_no_match(self, processor):
        """get_threshold falls back to the default for an unconfigured tag."""
        default_value = 0.6
        assert processor.get_threshold("UNKNOWN_TAG") == default_value

    def test_filter_classifications_all_pass(self, processor):
        """Every entry is kept when each confidence exceeds its threshold."""
        scores = {
            "TAG_A": 0.8,  # above the 0.7 threshold
            "TAG_B": 0.6,  # above the 0.5 threshold
        }
        kept = processor.filter_classifications(scores)
        assert len(kept) == 2
        assert all(tag in kept for tag in ("TAG_A", "TAG_B"))

    def test_filter_classifications_some_pass(self, processor):
        """Only entries that clear their threshold are kept."""
        scores = {
            "TAG_A": 0.6,  # below the 0.7 threshold
            "TAG_B": 0.6,  # above the 0.5 threshold
            "TAG_C": 0.9,  # above the 0.8 threshold
        }
        kept = processor.filter_classifications(scores)
        assert len(kept) == 2
        assert "TAG_A" not in kept
        assert all(tag in kept for tag in ("TAG_B", "TAG_C"))

    def test_filter_classifications_none_pass(self, processor):
        """Nothing is kept when every confidence is below its threshold."""
        scores = {
            "TAG_A": 0.6,  # below the 0.7 threshold
            "TAG_C": 0.7,  # below the 0.8 threshold
        }
        kept = processor.filter_classifications(scores)
        assert len(kept) == 0

    def test_filter_classifications_default_threshold(self, processor):
        """An unconfigured tag is judged against the default threshold."""
        scores = {
            "UNKNOWN_TAG": 0.7,  # above the 0.6 default threshold
        }
        kept = processor.filter_classifications(scores)
        assert len(kept) == 1
        assert "UNKNOWN_TAG" in kept

    def test_filter_classifications_string_confidence(self, processor):
        """Numeric strings are expected to be treated as their float value."""
        scores = {
            "TAG_A": "0.8",  # expected to pass once read as 0.8
            "TAG_B": "0.4",  # expected to fail once read as 0.4
        }
        kept = processor.filter_classifications(scores)
        assert len(kept) == 1
        assert "TAG_A" in kept
        assert "TAG_B" not in kept

    def test_filter_classifications_invalid_confidence(self, processor):
        """A non-numeric confidence is dropped without affecting valid entries."""
        scores = {
            "TAG_A": 0.8,
            "TAG_B": "not a number",  # not convertible to a number
            "TAG_C": 0.9,
        }
        kept = processor.filter_classifications(scores)
        assert len(kept) == 2
        assert "TAG_B" not in kept
        assert all(tag in kept for tag in ("TAG_A", "TAG_C"))

    def test_filter_classifications_exact_threshold(self, processor):
        """A confidence equal to its threshold is expected to be kept."""
        scores = {
            "TAG_A": 0.7,  # equal to the 0.7 threshold
            "TAG_B": 0.5,  # equal to the 0.5 threshold
        }
        kept = processor.filter_classifications(scores)
        assert len(kept) == 2
        assert all(tag in kept for tag in ("TAG_A", "TAG_B"))

    def test_filter_classifications_empty_dict(self, processor):
        """Filtering an empty mapping yields an empty dict."""
        kept = processor.filter_classifications({})
        assert kept == {}
0 commit comments