Skip to content

Commit d537302

Browse files
committed
add per indicator thrsholding and new dump
1 parent 88eeb0d commit d537302

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

scripts/ej/cmr_to_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def categorize_processing_level(level):
6969
# remove existing data
7070
EnvironmentalJusticeRow.objects.filter(destination_server=EnvironmentalJusticeRow.DestinationServerChoices.DEV).delete()
7171

72-
ej_dump = json.load(open("backups/ej_dump_20240815_112916.json"))
72+
ej_dump = json.load(open("backups/ej_dump_20241017_133151.json.json"))
7373
for dataset in ej_dump:
7474
ej_row = EnvironmentalJusticeRow(
7575
destination_server=EnvironmentalJusticeRow.DestinationServerChoices.DEV,

scripts/ej/create_ej_dump.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
inferences are supplied by the classification model. the contact point is Bishwas
33
cmr is supplied by running
44
github.com/NASA-IMPACT/llm-app-EJ-classifier/blob/develop/scripts/data_processing/download_cmr.py
5-
move to the serve like this: scp ej_dump_20240814_143036.json sde:/home/ec2-user/sde_indexing_helper/backups/
5+
move to the server like this: scp ej_dump_20241017_133151.json sde:/home/ec2-user/sde_indexing_helper/backups/
66
"""
77

88
import json
@@ -19,20 +19,22 @@ def save_to_json(data: dict | list, file_path: str) -> None:
1919
json.dump(data, file, indent=2)
2020

2121

22-
def process_classifications(predictions: list[dict[str, float]], threshold: float = 0.5) -> list[str]:
22+
def process_classifications(predictions: list[dict[str, float]], thresholds: dict[str, float]) -> list[str]:
2323
"""
24-
Process the predictions and classify as follows:
25-
1. If 'Not EJ' is the highest scoring prediction, return 'Not EJ' as the only classification
26-
2. Filter classifications based on the threshold, excluding 'Not EJ'
27-
3. Default to 'Not EJ' if no classifications meet the threshold
24+
Process the predictions and classify based on the individual thresholds per indicator:
25+
1. If 'Not EJ' is the highest scoring prediction, return 'Not EJ' as the only classification.
26+
2. Filter classifications based on their individual thresholds, excluding 'Not EJ'.
27+
3. Default to 'Not EJ' if no classifications meet the threshold.
2828
"""
2929
highest_prediction = max(predictions, key=lambda x: x["score"])
3030

3131
if highest_prediction["label"] == "Not EJ":
3232
return ["Not EJ"]
3333

3434
classifications = [
35-
pred["label"] for pred in predictions if pred["score"] >= threshold and pred["label"] != "Not EJ"
35+
pred["label"]
36+
for pred in predictions
37+
if pred["score"] >= thresholds[pred["label"]] and pred["label"] != "Not EJ"
3638
]
3739

3840
return classifications if classifications else ["Not EJ"]
@@ -63,14 +65,14 @@ def remove_unauthorized_classifications(classifications: list[str]) -> list[str]
6365
def update_cmr_with_classifications(
6466
inferences: list[dict[str, dict]],
6567
cmr_dict: dict[str, dict[str, dict]],
66-
threshold: float = 0.5,
68+
thresholds: dict[str, float],
6769
) -> list[dict[str, dict]]:
6870
"""Update CMR data with valid classifications based on inferences."""
6971

7072
predicted_cmr = []
7173

7274
for inference in inferences:
73-
classifications = process_classifications(predictions=inference["predictions"], threshold=threshold)
75+
classifications = process_classifications(predictions=inference["predictions"], thresholds=thresholds)
7476
classifications = remove_unauthorized_classifications(classifications)
7577

7678
if classifications:
@@ -84,17 +86,30 @@ def update_cmr_with_classifications(
8486

8587

8688
def main():
87-
inferences = load_json_file("cmr-inference.json")
89+
thresholds = {
90+
"Not EJ": 0.80,
91+
"Climate Change": 0.95,
92+
"Disasters": 0.80,
93+
"Extreme Heat": 0.50,
94+
"Food Availability": 0.80,
95+
"Health & Air Quality": 0.90,
96+
"Human Dimensions": 0.80,
97+
"Urban Flooding": 0.50,
98+
"Water Availability": 0.80,
99+
}
100+
101+
inferences = load_json_file("alpha-1.3-wise-vortex-42-predictions.json")
88102
cmr = load_json_file("cmr_collections_umm_20240807_142146.json")
89103

90104
cmr_dict = create_cmr_dict(cmr)
91105

92-
predicted_cmr = update_cmr_with_classifications(inferences=inferences, cmr_dict=cmr_dict, threshold=0.8)
106+
predicted_cmr = update_cmr_with_classifications(inferences=inferences, cmr_dict=cmr_dict, thresholds=thresholds)
93107

94108
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
95109
file_name = f"ej_dump_{timestamp}.json"
96110

97111
save_to_json(predicted_cmr, file_name)
112+
print(f"Saved to {file_name}")
98113

99114

100115
if __name__ == "__main__":

0 commit comments

Comments
 (0)