Skip to content

Commit 403c71d

Browse files
committed
update threshold processing to better handle not ej cases
1 parent 6a1906f commit 403c71d

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

scripts/ej/threshold_processing.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,15 @@ def process_predictions(self, predictions: list[dict[str, float]]) -> list[str]:
2727
2828
Args:
2929
predictions: List of dictionaries containing prediction labels and scores.
30-
Each dict should have 'label' and 'score' keys.
30+
Each dict should have 'label' and 'score' keys.
3131
3232
Returns:
3333
List of classification labels that meet their respective thresholds.
3434
"""
35+
# Handle empty predictions
36+
if not predictions:
37+
return ["Not EJ"]
38+
3539
# Find highest scoring prediction
3640
highest_prediction = max(predictions, key=lambda x: x["score"])
3741

@@ -43,7 +47,11 @@ def process_predictions(self, predictions: list[dict[str, float]]) -> list[str]:
4347
classifications = [
4448
pred["label"]
4549
for pred in predictions
46-
if (pred["score"] >= self.thresholds[pred["label"]] and pred["label"] != "Not EJ")
50+
if (
51+
pred["label"] in self.thresholds # Only check labels we have thresholds for
52+
and pred["score"] >= self.thresholds[pred["label"]]
53+
and pred["label"] != "Not EJ"
54+
)
4755
]
4856

4957
# Default to "Not EJ" if no classifications meet thresholds

0 commit comments

Comments
 (0)