Skip to content

Commit 7e14578

Browse files
Merge pull request #1278 from NASA-IMPACT/1277-class-based-thresholding
1277 class based thresholding
2 parents 6139a18 + 54fba98 commit 7e14578

File tree

6 files changed

+294
-21
lines changed

6 files changed

+294
-21
lines changed

inference/tests/test_classification_utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -127,17 +127,6 @@ def test_complex_mappings(self):
127127
actual_tags = map_classification_to_tdamm_tags(classification_results, threshold=0.8)
128128
assert sorted(actual_tags) == sorted(expected_tags)
129129

130-
@patch("django.conf.settings.TDAMM_CLASSIFICATION_THRESHOLD", 0.75)
131-
def test_default_threshold_from_settings(self):
132-
"""Test using the default threshold from settings"""
133-
classification_results = {"Optical": 0.7, "Infrared": 0.8, "X-rays": 0.9}
134-
135-
# With settings threshold of 0.75, Infrared and X-rays should be included
136-
expected_tags = ["MMA_M_EM_I", "MMA_M_EM_X"]
137-
actual_tags = map_classification_to_tdamm_tags(classification_results) # No threshold provided
138-
139-
assert sorted(actual_tags) == sorted(expected_tags)
140-
141130

142131
class TestUpdateUrlWithClassificationResults:
143132
"""Tests for the update_url_with_classification_results function"""
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
# inference/tests/test_threshold_processor.py
2+
# docker-compose -f local.yml run --rm django pytest inference/tests/test_threshold_processor.py
3+
4+
from unittest.mock import patch
5+
6+
import pytest
7+
8+
from inference.utils.threshold_processor import ClassificationThresholdProcessor
9+
10+
11+
class TestClassificationThresholdProcessor:
    """Unit tests for ClassificationThresholdProcessor: construction, lookup and filtering."""

    @pytest.fixture
    def test_thresholds(self):
        """Per-tag thresholds shared by the tests in this class."""
        return dict(TAG_A=0.7, TAG_B=0.5, TAG_C=0.8)

    @pytest.fixture
    def processor(self, test_thresholds):
        """Processor built from the shared thresholds with a 0.6 fallback."""
        return ClassificationThresholdProcessor(test_thresholds, default_threshold=0.6)

    def test_initialization(self, test_thresholds):
        """The constructor exposes the thresholds and the default it was given."""
        built = ClassificationThresholdProcessor(test_thresholds, default_threshold=0.6)
        assert built.thresholds == test_thresholds
        assert built.default_threshold == 0.6

    def test_for_tdamm_factory(self):
        """for_tdamm() picks up the TDAMM constants from the processor module."""
        with patch.multiple(
            "inference.utils.threshold_processor",
            TDAMM_TAG_THRESHOLDS={"TAG1": 0.7},
            DEFAULT_TDAMM_THRESHOLD=0.5,
        ):
            built = ClassificationThresholdProcessor.for_tdamm()
            assert built.thresholds == {"TAG1": 0.7}
            assert built.default_threshold == 0.5

    def test_for_division_factory(self):
        """for_division() picks up the Division constants from the processor module."""
        with patch.multiple(
            "inference.utils.threshold_processor",
            DIVISION_TAG_THRESHOLDS={1: 0.7},
            DEFAULT_DIVISION_THRESHOLD=0.5,
        ):
            built = ClassificationThresholdProcessor.for_division()
            assert built.thresholds == {1: 0.7}
            assert built.default_threshold == 0.5

    def test_get_threshold_exact_match(self, processor):
        """Known tags resolve to their configured threshold."""
        expected_by_tag = {"TAG_A": 0.7, "TAG_B": 0.5, "TAG_C": 0.8}
        for tag, expected in expected_by_tag.items():
            assert processor.get_threshold(tag) == expected

    def test_get_threshold_no_match(self, processor):
        """Unknown tags fall back to the default threshold (0.6 here)."""
        assert processor.get_threshold("UNKNOWN_TAG") == 0.6

    def test_filter_classifications_all_pass(self, processor):
        """Every entry is kept when each one clears its own threshold."""
        # TAG_A: 0.8 vs 0.7, TAG_B: 0.6 vs 0.5
        kept = processor.filter_classifications({"TAG_A": 0.8, "TAG_B": 0.6})
        assert set(kept) == {"TAG_A", "TAG_B"}

    def test_filter_classifications_some_pass(self, processor):
        """Only the entries that clear their threshold are kept."""
        # TAG_A: 0.6 vs 0.7 (fails), TAG_B: 0.6 vs 0.5, TAG_C: 0.9 vs 0.8
        kept = processor.filter_classifications({"TAG_A": 0.6, "TAG_B": 0.6, "TAG_C": 0.9})
        assert set(kept) == {"TAG_B", "TAG_C"}

    def test_filter_classifications_none_pass(self, processor):
        """Nothing is kept when every entry is below its threshold."""
        # TAG_A: 0.6 vs 0.7, TAG_C: 0.7 vs 0.8
        kept = processor.filter_classifications({"TAG_A": 0.6, "TAG_C": 0.7})
        assert not kept

    def test_filter_classifications_default_threshold(self, processor):
        """Tags without a configured threshold are judged against the default."""
        # 0.7 vs default 0.6
        kept = processor.filter_classifications({"UNKNOWN_TAG": 0.7})
        assert set(kept) == {"UNKNOWN_TAG"}

    def test_filter_classifications_string_confidence(self, processor):
        """Numeric strings are converted before being compared."""
        # "0.8" passes TAG_A's 0.7; "0.4" fails TAG_B's 0.5
        kept = processor.filter_classifications({"TAG_A": "0.8", "TAG_B": "0.4"})
        assert set(kept) == {"TAG_A"}

    def test_filter_classifications_invalid_confidence(self, processor):
        """Unparseable string confidences are dropped without affecting the rest."""
        scores = {"TAG_A": 0.8, "TAG_B": "not a number", "TAG_C": 0.9}
        kept = processor.filter_classifications(scores)
        assert set(kept) == {"TAG_A", "TAG_C"}

    def test_filter_classifications_exact_threshold(self, processor):
        """A confidence exactly equal to its threshold is kept."""
        kept = processor.filter_classifications({"TAG_A": 0.7, "TAG_B": 0.5})
        assert set(kept) == {"TAG_A", "TAG_B"}

    def test_filter_classifications_empty_dict(self, processor):
        """An empty input yields an empty result."""
        kept = processor.filter_classifications({})
        assert kept == {}

inference/utils/classification_utils.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from django.conf import settings
22

3+
from inference.utils.threshold_processor import ClassificationThresholdProcessor
34
from sde_collections.models.collection_choice_fields import TDAMMTags
45

56

@@ -18,6 +19,9 @@ def map_classification_to_tdamm_tags(classification_results, threshold=None):
1819
if threshold is None:
1920
threshold = float(getattr(settings, "TDAMM_CLASSIFICATION_THRESHOLD"))
2021

22+
# Initialize the threshold processor
23+
threshold_processor = ClassificationThresholdProcessor.for_tdamm()
24+
2125
selected_tags = []
2226

2327
# Build a mapping from simplified tag names to actual TDAMMTags values
@@ -35,30 +39,36 @@ def map_classification_to_tdamm_tags(classification_results, threshold=None):
3539
tag_mapping["supernovae"] = tag_value
3640

3741
# Process classification results
42+
tdamm_confidences = {}
3843
for classification_key, confidence in classification_results.items():
3944
if isinstance(confidence, str):
4045
try:
4146
confidence = float(confidence)
4247
except (ValueError, TypeError):
4348
continue
4449

45-
if confidence < threshold:
46-
continue
47-
4850
# Normalize the classification key
4951
normalized_key = classification_key.lower()
52+
tag_value = None
5053

5154
# Try to find a match in our mapping
5255
if normalized_key in tag_mapping:
53-
selected_tags.append(tag_mapping[normalized_key])
56+
tag_value = tag_mapping[normalized_key]
5457
else:
55-
# Try partial matching for more complex cases
56-
for tag_key, tag_value in tag_mapping.items():
57-
if tag_key in normalized_key or normalized_key in tag_key:
58-
selected_tags.append(tag_value)
58+
# Try partial matching
59+
for key, value in tag_mapping.items():
60+
if key in normalized_key or normalized_key in key:
61+
tag_value = value
5962
break
6063

61-
return selected_tags
64+
# Skip if no matching tag found
65+
if not tag_value:
66+
continue
67+
68+
tdamm_confidences[tag_value] = confidence
69+
70+
selected_tags = threshold_processor.filter_classifications(tdamm_confidences)
71+
return list(selected_tags.keys())
6272

6373

6474
def update_url_with_classification_results(url_object, classification_results, threshold=None):

inference/utils/config.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Configuration settings for classification thresholds."""
2+
3+
# Configuration settings for TDAMM tag classification thresholds
4+
# Format: "tag_name": threshold_value (0.0 to 1.0)
5+
TDAMM_TAG_THRESHOLDS = {
6+
"NOT_TDAMM": 0.7, # non-TDAMM
7+
"MMA_O_BH_AGN": 0.5, # Active Galactic Nuclei
8+
"MMA_O_BI_BBH": 0.7, # Binary Black Holes
9+
"MMA_O_BI_BNS": 0.6, # Binary Neutron Stars
10+
"MMA_O_BI_B": 0.7, # Binary Pulsars
11+
"MMA_M_G_B": 0.5, # Burst
12+
"MMA_O_BI_C": 0.7, # Cataclysmic Variables
13+
"MMA_M_G_CBI": 0.5, # Compact Binary Inspiral
14+
"MMA_M_G_CON": 0.5, # Continuous
15+
"MMA_M_C": 0.8, # Cosmic Rays
16+
"MMA_O_E": 0.7, # Exoplanets
17+
"MMA_S_FBOT": 0.7, # Fast Blue Optical Transients
18+
"MMA_S_F": 0.7, # Fast Radio Bursts
19+
"MMA_M_EM_G": 0.5, # Gamma rays
20+
"MMA_S_G": 0.8, # Gamma-ray Bursts
21+
"MMA_M_EM_I": 0.8, # Infrared
22+
"MMA_O_BH_IM": 0.5, # Intermediate Mass
23+
"MMA_S_K": 0.5, # Kilonovae
24+
"MMA_O_N_M": 0.7, # Magnetars
25+
"MMA_M_N": 0.5, # Neutrinos
26+
"MMA_O_BI_N": 0.5, # Neutron Star-Black Hole
27+
"MMA_S_N": 0.8, # Novae
28+
"MMA_M_EM_O": 0.7, # Optical
29+
"MMA_S_P": 0.5, # Pevatrons
30+
"MMA_O_N_PWN": 0.8, # Pulsar Wind Nebulae
31+
"MMA_O_N_P": 0.5, # Pulsars
32+
"MMA_M_EM_R": 0.8, # Radio
33+
"MMA_O_BH_STM": 0.5, # Stellar Mass
34+
"MMA_S_ST": 0.7, # Stellar flares
35+
"MMA_M_G_S": 0.8, # Stochastic
36+
"MMA_S_SU": 0.8, # SuperNovae
37+
"MMA_O_BH_SUM": 0.5, # Supermassive
38+
"MMA_O_S": 0.6, # Supernova Remnants
39+
"MMA_M_EM_U": 0.7, # Ultraviolet
40+
"MMA_O_BI_W": 0.7, # White Dwarf Binaries
41+
"MMA_M_EM_X": 0.8, # X-rays
42+
}
43+
44+
# Default threshold to use if a specific tag isn't defined above
45+
DEFAULT_TDAMM_THRESHOLD = 0.5
46+
47+
# Threshold values for different Division classifications
48+
DIVISION_TAG_THRESHOLDS = {
49+
"Astrophysics": 0.5,
50+
"Biological and Physical Sciences": 0.5,
51+
"Earth Science": 0.5,
52+
"Heliophysics": 0.5,
53+
"Planetary Science": 0.5,
54+
"General": 0.5,
55+
}
56+
57+
# Default threshold for Division classification
58+
DEFAULT_DIVISION_THRESHOLD = 0.5
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Module for processing classifications with tag-specific thresholds."""
2+
3+
from inference.utils.config import (
4+
DEFAULT_DIVISION_THRESHOLD,
5+
DEFAULT_TDAMM_THRESHOLD,
6+
DIVISION_TAG_THRESHOLDS,
7+
TDAMM_TAG_THRESHOLDS,
8+
)
9+
10+
11+
class ClassificationThresholdProcessor:
12+
"""
13+
Generic processor for classifications using tag-specific thresholds.
14+
Can be used with any classification system where different classes
15+
need different confidence thresholds.
16+
"""
17+
18+
def __init__(self, thresholds: dict[str, float], default_threshold: float = 0.5):
19+
"""
20+
Initialize the processor with thresholds.
21+
22+
Args:
23+
thresholds: Dictionary of classification tags and their threshold values.
24+
default_threshold: Default threshold to use if tag isn't in thresholds.
25+
"""
26+
self.thresholds = thresholds
27+
self.default_threshold = default_threshold
28+
29+
@classmethod
30+
def for_tdamm(cls):
31+
"""Create a processor for TDAMM classification."""
32+
return cls(TDAMM_TAG_THRESHOLDS, DEFAULT_TDAMM_THRESHOLD)
33+
34+
@classmethod
35+
def for_division(cls):
36+
"""Create a processor for Division classification."""
37+
return cls(DIVISION_TAG_THRESHOLDS, DEFAULT_DIVISION_THRESHOLD)
38+
39+
def get_threshold(self, tag: str) -> float:
40+
"""
41+
Get the threshold for a tag.
42+
43+
Args:
44+
tag: The tag to get threshold for
45+
46+
Returns:
47+
The threshold value as a float
48+
"""
49+
return self.thresholds.get(tag, self.default_threshold)
50+
51+
def filter_classifications(self, classifications: dict[str, float | str]) -> dict[str, float]:
52+
"""
53+
Filter classifications based on their thresholds.
54+
55+
Args:
56+
classifications: Dictionary with classification keys and confidence scores
57+
58+
Returns:
59+
Dictionary with classifications that passed their thresholds
60+
"""
61+
result = {}
62+
for key, confidence in classifications.items():
63+
# Convert confidence to float if it's a string
64+
if isinstance(confidence, str):
65+
try:
66+
confidence_value = float(confidence)
67+
except (ValueError, TypeError):
68+
continue
69+
else:
70+
confidence_value = confidence
71+
72+
# Get the threshold for this classification
73+
threshold = self.get_threshold(key)
74+
75+
# Keep only classifications that meet their threshold
76+
if confidence_value >= threshold:
77+
result[key] = confidence_value
78+
79+
return result

production.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ services:
5858
image: sde_indexing_helper_production_celerybeat
5959
container_name: sde_indexing_helper_production_celerybeat
6060
depends_on:
61-
- awscli
6261
- postgres
6362
ports: []
6463
command: /start-celerybeat

0 commit comments

Comments
 (0)