Skip to content

Commit c581f83

Browse files
committed
refactor classification thresholding
1 parent 50f857b commit c581f83

File tree

1 file changed

+75
-47
lines changed

1 file changed

+75
-47
lines changed

scripts/ej/create_ej_dump.py

Lines changed: 75 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,100 @@
11
"""
22
Inferences are supplied by the classification model; the contact point is Bishwas.
33
cmr is supplied by running https://github.com/NASA-IMPACT/llm-app-EJ-classifier/blob/develop/scripts/data_processing/download_cmr.py
4-
move to the serve like this: scp scripts/ej/ej_dump_20240814_143036.json sde:/home/ec2-user/sde_indexing_helper/backups/
4+
Move to the server like this: scp ej_dump_20240814_143036.json sde:/home/ec2-user/sde_indexing_helper/backups/
55
"""
66

77
import json
88
from datetime import datetime
99

10-
inferences = json.load(open("cmr-inference.json"))
11-
cmr = json.load(open("cmr_collections_umm_20240807_142146.json"))
1210

11+
def load_json_file(file_path: str) -> dict:
12+
with open(file_path, "r") as file:
13+
return json.load(file)
1314

14-
def process_classifications(data: dict[str, any], threshold: float = 0.5) -> dict[str, any]:
15+
16+
def save_to_json(data: dict | list, file_path: str) -> None:
17+
with open(file_path, "w") as file:
18+
json.dump(data, file, indent=2)
19+
20+
21+
def process_classifications(predictions: list[dict[str, float]], threshold: float = 0.5) -> list[str]:
1522
"""
16-
Takes a classification dict as input and processes them as follows:
17-
1. If 'Not EJ' is the highest scoring prediction, it returns 'Not EJ' as the only classification.
18-
2. If 'Not EJ' is not the highest, it filters the classifications based on the provided threshold, excluding Not EJ.
19-
3. If no classifications pass the threshold, it defaults to 'EJ'.
23+
Process the predictions and classify as follows:
24+
1. If 'Not EJ' is the highest scoring prediction, return 'Not EJ' as the only classification
25+
2. Filter classifications based on the threshold, excluding 'Not EJ'
26+
3. Default to 'Not EJ' if no classifications meet the threshold
2027
"""
21-
predictions = data["predictions"]
22-
23-
# Sort predictions by score in descending order and get the highest
24-
highest_prediction = sorted(predictions, key=lambda x: x["score"], reverse=True)[0]
28+
highest_prediction = max(predictions, key=lambda x: x["score"])
2529

26-
# Determine classifications based on the conditions
2730
if highest_prediction["label"] == "Not EJ":
28-
classifications = ["Not EJ"]
29-
else:
30-
classifications = [
31-
pred["label"] for pred in predictions if pred["score"] >= threshold and pred["label"] != "Not EJ"
32-
]
33-
if not classifications:
34-
classifications = ["EJ"]
31+
return ["Not EJ"]
32+
33+
classifications = [
34+
pred["label"] for pred in predictions if pred["score"] >= threshold and pred["label"] != "Not EJ"
35+
]
36+
37+
return classifications if classifications else ["Not EJ"]
38+
39+
40+
def create_cmr_dict(cmr_data: list[dict[str, dict[str, str]]]) -> dict[str, dict[str, dict[str, str]]]:
    """Index the CMR records by the 'concept-id' found under each record's 'meta' key.

    The record objects themselves are stored (not copies). If several records
    share a concept-id, the last one wins.
    """
    by_concept_id = {}
    for record in cmr_data:
        concept_id = record["meta"]["concept-id"]
        by_concept_id[concept_id] = record
    return by_concept_id
43+
44+
45+
# Built once at import time; a frozenset gives constant-time membership tests
# and cannot be modified by accident.
_AUTHORIZED_CLASSIFICATIONS = frozenset(
    {
        "Climate Change",
        "Disasters",
        "Extreme Heat",
        "Food Availability",
        "Health & Air Quality",
        "Human Dimensions",
        "Urban Flooding",
        "Water Availability",
    }
)


def remove_unauthorized_classifications(classifications: list[str]) -> list[str]:
    """Filter classifications to keep only those in the authorized list.

    Args:
        classifications: Candidate labels.

    Returns:
        A new list holding the authorized labels in their input order
        (duplicates are kept). Anything else, including 'Not EJ', is dropped.
    """
    return [cls for cls in classifications if cls in _AUTHORIZED_CLASSIFICATIONS]
60+
61+
62+
def update_cmr_with_classifications(
    inferences: list[dict[str, dict]],
    cmr_dict: dict[str, dict[str, dict]],
    threshold: float = 0.5,
) -> list[dict[str, dict]]:
    """Attach the authorized classifications of each inference to its CMR dataset.

    For every inference, the predictions are thresholded and then reduced to
    the authorized labels. Inferences left without any label, or whose
    'concept-id' has no (non-empty) entry in ``cmr_dict``, are skipped.

    Matched datasets are modified in place: their 'indicators' key is set to
    the labels joined by ';'. The returned list holds those same dataset
    objects, in the order of ``inferences``.
    """
    matched_datasets = []

    for record in inferences:
        labels = remove_unauthorized_classifications(
            process_classifications(predictions=record["predictions"], threshold=threshold)
        )
        if not labels:
            continue

        dataset = cmr_dict.get(record["concept-id"])
        if not dataset:
            continue

        dataset["indicators"] = ";".join(labels)
        matched_datasets.append(dataset)

    return matched_datasets
4183

42-
predicted_cmr = []
4384

44-
authorized_classifications = [
45-
"Climate Change",
46-
"Disasters",
47-
"Extreme Heat",
48-
"Food Availability",
49-
"Health & Air Quality",
50-
"Human Dimensions",
51-
"Urban Flooding",
52-
"Water Availability",
53-
]
85+
def main(
    inference_path: str = "cmr-inference.json",
    cmr_path: str = "cmr_collections_umm_20240807_142146.json",
    threshold: float = 0.8,
) -> None:
    """Build the EJ dump: classify the CMR datasets and write them to a timestamped JSON file.

    Args:
        inference_path: JSON file holding the classification model's inferences.
        cmr_path: JSON file holding the CMR collections dump.
        threshold: Minimum score a predicted label needs to be kept.

    The defaults reproduce the previously hard-coded inputs, so calling
    ``main()`` without arguments behaves as before. The output is written to
    ``ej_dump_<YYYYmmdd_HHMMSS>.json`` in the current working directory.
    """
    inferences = load_json_file(inference_path)
    cmr = load_json_file(cmr_path)

    cmr_dict = create_cmr_dict(cmr)

    predicted_cmr = update_cmr_with_classifications(inferences=inferences, cmr_dict=cmr_dict, threshold=threshold)

    # The timestamp uses local time and only serves to keep successive dumps apart.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = f"ej_dump_{timestamp}.json"

    save_to_json(predicted_cmr, file_name)
6897

69-
timestamp: str = datetime.now().strftime("%Y%m%d_%H%M%S")
70-
file_name: str = f"ej_dump_{timestamp}.json"
7198

72-
json.dump(predicted_cmr, open(file_name, "w"), indent=2)
99+
if __name__ == "__main__":
100+
main()

0 commit comments

Comments
 (0)