Skip to content

Commit ab9bd9d

Browse files
committed
Added tests for classification utils
1 parent 630363b commit ab9bd9d

File tree

1 file changed

+249
-0
lines changed

1 file changed

+249
-0
lines changed
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
# inference/tests/test_classification_utils.py
2+
# docker-compose -f local.yml run --rm django pytest inference/tests/test_classification_utils.py
3+
4+
from unittest.mock import Mock, patch
5+
6+
import pytest
7+
8+
from inference.utils.classification_utils import (
9+
map_classification_to_tdamm_tags,
10+
update_url_with_classification_results,
11+
)
12+
13+
14+
class TestMapClassificationToTDAMMTags:
15+
"""Tests for the map_classification_to_tdamm_tags function"""
16+
17+
def test_basic_mapping(self):
18+
"""Test basic mapping of classification results to TDAMM tags"""
19+
classification_results = {"Optical": 0.9, "Infrared": 0.85, "X-rays": 0.95}
20+
21+
expected_tags = [
22+
"MMA_M_EM_O", # Optical
23+
"MMA_M_EM_I", # Infrared
24+
"MMA_M_EM_X", # X-rays
25+
]
26+
27+
actual_tags = map_classification_to_tdamm_tags(classification_results, threshold=0.8)
28+
assert sorted(actual_tags) == sorted(expected_tags)
29+
30+
def test_threshold_handling(self):
31+
"""Test that only tags above the threshold are included"""
32+
classification_results = {
33+
"Optical": 0.9, # Above threshold
34+
"Infrared": 0.55, # Below threshold
35+
"X-rays": 0.7, # Below threshold
36+
"Radio": 0.85, # Above threshold
37+
}
38+
39+
expected_tags = [
40+
"MMA_M_EM_O", # Optical
41+
"MMA_M_EM_R", # Radio
42+
]
43+
44+
actual_tags = map_classification_to_tdamm_tags(classification_results, threshold=0.8)
45+
assert sorted(actual_tags) == sorted(expected_tags)
46+
47+
def test_case_insensitivity(self):
48+
"""Test that the mapping works regardless of case"""
49+
classification_results = {
50+
"optical": 0.9, # Lowercase
51+
"INFRARED": 0.85, # Uppercase
52+
"X-Rays": 0.95, # Mixed case
53+
}
54+
55+
expected_tags = [
56+
"MMA_M_EM_O", # Optical
57+
"MMA_M_EM_I", # Infrared
58+
"MMA_M_EM_X", # X-rays
59+
]
60+
61+
actual_tags = map_classification_to_tdamm_tags(classification_results, threshold=0.8)
62+
assert sorted(actual_tags) == sorted(expected_tags)
63+
64+
def test_special_cases(self):
65+
"""Test special case mappings"""
66+
classification_results = {
67+
"non-TDAMM": 0.95,
68+
"supernovae": 0.9,
69+
}
70+
71+
expected_tags = [
72+
"NOT_TDAMM",
73+
"MMA_S_SU",
74+
]
75+
76+
actual_tags = map_classification_to_tdamm_tags(classification_results, threshold=0.8)
77+
assert sorted(actual_tags) == sorted(expected_tags)
78+
79+
def test_string_confidence_values(self):
80+
"""Test handling string confidence values"""
81+
classification_results = {"Optical": "0.9", "Infrared": 0.85, "X-rays": "0.95"} # String # Float # String
82+
83+
expected_tags = [
84+
"MMA_M_EM_O", # Optical
85+
"MMA_M_EM_I", # Infrared
86+
"MMA_M_EM_X", # X-rays
87+
]
88+
89+
actual_tags = map_classification_to_tdamm_tags(classification_results, threshold=0.8)
90+
assert sorted(actual_tags) == sorted(expected_tags)
91+
92+
def test_invalid_confidence_values(self):
93+
"""Test handling invalid confidence values"""
94+
classification_results = {"Optical": 0.9, "Infrared": "not_a_number", "X-rays": 0.95}
95+
96+
expected_tags = [
97+
"MMA_M_EM_O", # Optical
98+
"MMA_M_EM_X", # X-rays
99+
]
100+
101+
actual_tags = map_classification_to_tdamm_tags(classification_results, threshold=0.8)
102+
assert sorted(actual_tags) == sorted(expected_tags)
103+
104+
def test_empty_classification_results(self):
105+
"""Test handling of empty classification results"""
106+
classification_results = {}
107+
108+
actual_tags = map_classification_to_tdamm_tags(classification_results)
109+
assert actual_tags == []
110+
111+
def test_complex_mappings(self):
112+
"""Test more complex mappings with specific TDAMM categories"""
113+
classification_results = {
114+
"Binary Black Holes": 0.9,
115+
"Neutron Star-Black Hole": 0.85,
116+
"Gamma-ray Bursts": 0.95,
117+
"Fast Blue Optical Transients": 0.8,
118+
}
119+
120+
expected_tags = [
121+
"MMA_O_BI_BBH", # Binary Black Holes
122+
"MMA_O_BI_N", # Neutron Star-Black Hole
123+
"MMA_S_G", # Gamma-ray Bursts
124+
"MMA_S_FBOT", # Fast Blue Optical Transients
125+
]
126+
127+
actual_tags = map_classification_to_tdamm_tags(classification_results, threshold=0.8)
128+
assert sorted(actual_tags) == sorted(expected_tags)
129+
130+
@patch("django.conf.settings.TDAMM_CLASSIFICATION_THRESHOLD", 0.75)
131+
def test_default_threshold_from_settings(self):
132+
"""Test using the default threshold from settings"""
133+
classification_results = {"Optical": 0.7, "Infrared": 0.8, "X-rays": 0.9}
134+
135+
# With settings threshold of 0.75, Infrared and X-rays should be included
136+
expected_tags = ["MMA_M_EM_I", "MMA_M_EM_X"]
137+
actual_tags = map_classification_to_tdamm_tags(classification_results) # No threshold provided
138+
139+
assert sorted(actual_tags) == sorted(expected_tags)
140+
141+
142+
class TestUpdateUrlWithClassificationResults:
143+
"""Tests for the update_url_with_classification_results function"""
144+
145+
@pytest.fixture
146+
def mock_url(self):
147+
"""Create a mock URL object for testing"""
148+
url = Mock()
149+
url.tdamm_tag_ml = None
150+
url.save = Mock()
151+
return url
152+
153+
@patch("inference.utils.classification_utils.map_classification_to_tdamm_tags")
154+
def test_update_url_properly_calls_mapping(self, mock_map_function, mock_url):
155+
"""Test that URL objects are correctly updated with TDAMM tags"""
156+
# Set up mock return value
157+
mock_tdamm_tags = ["MMA_M_EM_O", "MMA_M_EM_X"]
158+
mock_map_function.return_value = mock_tdamm_tags
159+
160+
# Test data
161+
classification_results = {"Optical": 0.9, "X-rays": 0.85}
162+
163+
# Call the function
164+
result = update_url_with_classification_results(mock_url, classification_results)
165+
166+
# Verify map_classification_to_tdamm_tags was called properly
167+
mock_map_function.assert_called_once_with(classification_results)
168+
169+
# Verify URL object was updated correctly
170+
assert mock_url.tdamm_tag_ml == mock_tdamm_tags
171+
mock_url.save.assert_called_once_with(update_fields=["tdamm_tag_ml"])
172+
173+
# Verify return value
174+
assert result == mock_tdamm_tags
175+
176+
@patch("inference.utils.classification_utils.map_classification_to_tdamm_tags")
177+
def test_threshold_parameter_behavior(self, mock_map_function, mock_url):
178+
"""Test how threshold parameter is handled"""
179+
mock_tdamm_tags = ["MMA_M_EM_O"]
180+
mock_map_function.return_value = mock_tdamm_tags
181+
182+
classification_results = {"Optical": 0.9}
183+
custom_threshold = 0.85
184+
185+
update_url_with_classification_results(mock_url, classification_results, threshold=custom_threshold)
186+
187+
# Based on the implementation, the function doesn't pass the threshold parameter
188+
mock_map_function.assert_called_once_with(classification_results)
189+
190+
def test_integration_with_real_mapping(self, mock_url):
191+
"""Test end-to-end integration with real mapping function"""
192+
classification_results = {"Optical": 0.9, "Binary Black Holes": 0.85, "Novae": 0.8}
193+
194+
expected_tags = ["MMA_M_EM_O", "MMA_O_BI_BBH", "MMA_S_N"]
195+
196+
result = update_url_with_classification_results(mock_url, classification_results, threshold=0.7)
197+
198+
assert sorted(result) == sorted(expected_tags)
199+
assert sorted(mock_url.tdamm_tag_ml) == sorted(expected_tags)
200+
201+
def test_full_mapping_coverage(self):
202+
"""Test that all provided mappings work correctly"""
203+
mapping = {
204+
"Optical": "MMA_M_EM_O",
205+
"Ultraviolet": "MMA_M_EM_U",
206+
"Exoplanets": "MMA_O_E",
207+
"Gamma rays": "MMA_M_EM_G",
208+
"Infrared": "MMA_M_EM_I",
209+
"Gamma-ray Bursts": "MMA_S_G",
210+
"SuperNovae": "MMA_S_SU",
211+
"non-TDAMM": "NOT_TDAMM",
212+
"Radio": "MMA_M_EM_R",
213+
"White Dwarf Binaries": "MMA_O_BI_W",
214+
"Pulsar Wind Nebulae": "MMA_O_N_PWN",
215+
"X-rays": "MMA_M_EM_X",
216+
"Compact Binary Inspiral": "MMA_M_G_CBI",
217+
"Stochastic": "MMA_M_G_S",
218+
"Continuous": "MMA_M_G_CON",
219+
"Supernova Remnants": "MMA_O_S",
220+
"Stellar flares": "MMA_S_ST",
221+
"Pulsars": "MMA_O_N_P",
222+
"Neutron Star-Black Hole": "MMA_O_BI_N",
223+
"Cosmic Rays": "MMA_M_C",
224+
"Binary Black Holes": "MMA_O_BI_BBH",
225+
"Burst": "MMA_M_G_B",
226+
"Binary Neutron Stars": "MMA_O_BI_BNS",
227+
"Fast Blue Optical Transients": "MMA_S_FBOT",
228+
"Cataclysmic Variables": "MMA_O_BI_C",
229+
"Binary Pulsars": "MMA_O_BI_B",
230+
"Active Galactic Nuclei": "MMA_O_BH_AGN",
231+
"Neutrinos": "MMA_M_N",
232+
"Fast Radio Bursts": "MMA_S_F",
233+
"Stellar Mass": "MMA_O_BH_STM",
234+
"Magnetars": "MMA_O_N_M",
235+
"Pevatrons": "MMA_S_P",
236+
"Novae": "MMA_S_N",
237+
"Kilonovae": "MMA_S_K",
238+
"Supermassive": "MMA_O_BH_SUM",
239+
"Intermediate Mass": "MMA_O_BH_IM",
240+
}
241+
242+
# Create classification results with all keys
243+
classification_results = {key: 1.0 for key in mapping.keys()}
244+
245+
# Map to TDAMM tags
246+
tdamm_tags = map_classification_to_tdamm_tags(classification_results, threshold=0.5)
247+
248+
# Verify all expected tags are present
249+
assert sorted(tdamm_tags) == sorted(list(mapping.values()))

0 commit comments

Comments
 (0)