1
1
from django .conf import settings
2
2
3
+ from inference .utils .threshold_processor import ClassificationThresholdProcessor
3
4
from sde_collections .models .collection_choice_fields import TDAMMTags
4
5
5
6
@@ -18,6 +19,9 @@ def map_classification_to_tdamm_tags(classification_results, threshold=None):
18
19
if threshold is None :
19
20
threshold = float (getattr (settings , "TDAMM_CLASSIFICATION_THRESHOLD" ))
20
21
22
+ # Initialize the threshold processor
23
+ threshold_processor = ClassificationThresholdProcessor .for_tdamm ()
24
+
21
25
selected_tags = []
22
26
23
27
# Build a mapping from simplified tag names to actual TDAMMTags values
@@ -35,30 +39,36 @@ def map_classification_to_tdamm_tags(classification_results, threshold=None):
35
39
tag_mapping ["supernovae" ] = tag_value
36
40
37
41
# Process classification results
42
+ tdamm_confidences = {}
38
43
for classification_key , confidence in classification_results .items ():
39
44
if isinstance (confidence , str ):
40
45
try :
41
46
confidence = float (confidence )
42
47
except (ValueError , TypeError ):
43
48
continue
44
49
45
- if confidence < threshold :
46
- continue
47
-
48
50
# Normalize the classification key
49
51
normalized_key = classification_key .lower ()
52
+ tag_value = None
50
53
51
54
# Try to find a match in our mapping
52
55
if normalized_key in tag_mapping :
53
- selected_tags . append ( tag_mapping [normalized_key ])
56
+ tag_value = tag_mapping [normalized_key ]
54
57
else :
55
- # Try partial matching for more complex cases
56
- for tag_key , tag_value in tag_mapping .items ():
57
- if tag_key in normalized_key or normalized_key in tag_key :
58
- selected_tags . append ( tag_value )
58
+ # Try partial matching
59
+ for key , value in tag_mapping .items ():
60
+ if key in normalized_key or normalized_key in key :
61
+ tag_value = value
59
62
break
60
63
61
- return selected_tags
64
+ # Skip if no matching tag found
65
+ if not tag_value :
66
+ continue
67
+
68
+ tdamm_confidences [tag_value ] = confidence
69
+
70
+ selected_tags = threshold_processor .filter_classifications (tdamm_confidences )
71
+ return list (selected_tags .keys ())
62
72
63
73
64
74
def update_url_with_classification_results (url_object , classification_results , threshold = None ):
0 commit comments