Skip to content

Commit 843dd00

Browse files
committed
Updated classification_utils to use the threshold_processor
1 parent d39f768 commit 843dd00

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

inference/utils/classification_utils.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from django.conf import settings
22

3+
from inference.utils.threshold_processor import ClassificationThresholdProcessor
34
from sde_collections.models.collection_choice_fields import TDAMMTags
45

56

@@ -18,6 +19,9 @@ def map_classification_to_tdamm_tags(classification_results, threshold=None):
1819
if threshold is None:
1920
threshold = float(getattr(settings, "TDAMM_CLASSIFICATION_THRESHOLD"))
2021

22+
# Initialize the threshold processor
23+
threshold_processor = ClassificationThresholdProcessor.for_tdamm()
24+
2125
selected_tags = []
2226

2327
# Build a mapping from simplified tag names to actual TDAMMTags values
@@ -35,30 +39,36 @@ def map_classification_to_tdamm_tags(classification_results, threshold=None):
3539
tag_mapping["supernovae"] = tag_value
3640

3741
# Process classification results
42+
tdamm_confidences = {}
3843
for classification_key, confidence in classification_results.items():
3944
if isinstance(confidence, str):
4045
try:
4146
confidence = float(confidence)
4247
except (ValueError, TypeError):
4348
continue
4449

45-
if confidence < threshold:
46-
continue
47-
4850
# Normalize the classification key
4951
normalized_key = classification_key.lower()
52+
tag_value = None
5053

5154
# Try to find a match in our mapping
5255
if normalized_key in tag_mapping:
53-
selected_tags.append(tag_mapping[normalized_key])
56+
tag_value = tag_mapping[normalized_key]
5457
else:
55-
# Try partial matching for more complex cases
56-
for tag_key, tag_value in tag_mapping.items():
57-
if tag_key in normalized_key or normalized_key in tag_key:
58-
selected_tags.append(tag_value)
58+
# Try partial matching
59+
for key, value in tag_mapping.items():
60+
if key in normalized_key or normalized_key in key:
61+
tag_value = value
5962
break
6063

61-
return selected_tags
64+
# Skip if no matching tag found
65+
if not tag_value:
66+
continue
67+
68+
tdamm_confidences[tag_value] = confidence
69+
70+
selected_tags = threshold_processor.filter_classifications(tdamm_confidences)
71+
return list(selected_tags.keys())
6272

6373

6474
def update_url_with_classification_results(url_object, classification_results, threshold=None):

0 commit comments

Comments
 (0)