2
2
Inferences are supplied by the classification model; the contact point is Bishwas.
3
3
cmr is supplied by running
4
4
github.com/NASA-IMPACT/llm-app-EJ-classifier/blob/develop/scripts/data_processing/download_cmr.py
5
- move to the serve like this: scp ej_dump_20240814_143036 .json sde:/home/ec2-user/sde_indexing_helper/backups/
5
+ move to the server like this: scp ej_dump_20241017_133151.json sde:/home/ec2-user/sde_indexing_helper/backups/
6
6
"""
7
7
8
8
import json
@@ -19,20 +19,22 @@ def save_to_json(data: dict | list, file_path: str) -> None:
19
19
json .dump (data , file , indent = 2 )
20
20
21
21
22
def process_classifications(predictions: list[dict[str, float]], thresholds: dict[str, float]) -> list[str]:
    """
    Process the predictions and classify based on the individual thresholds per indicator:
    1. If there are no predictions at all, return 'Not EJ' as the only classification.
    2. If 'Not EJ' is the highest scoring prediction, return 'Not EJ' as the only classification.
    3. Filter classifications based on their individual thresholds, excluding 'Not EJ'.
    4. Default to 'Not EJ' if no classifications meet their threshold.

    Args:
        predictions: one dict per label, each read via its "label" and "score" keys.
        thresholds: minimum score (inclusive) per label. An entry for 'Not EJ' is
            not required, because 'Not EJ' is never compared against a threshold.

    Returns:
        The labels whose score reaches their threshold, in the order they appear
        in ``predictions``, or ``["Not EJ"]`` when nothing qualifies.

    Raises:
        KeyError: if a label other than 'Not EJ' has no entry in ``thresholds``.
    """
    # max() raises ValueError on an empty sequence; no predictions means nothing
    # can meet a threshold, so fall back to the documented default.
    if not predictions:
        return ["Not EJ"]

    highest_prediction = max(predictions, key=lambda x: x["score"])

    if highest_prediction["label"] == "Not EJ":
        return ["Not EJ"]

    classifications = [
        pred["label"]
        for pred in predictions
        # Exclude 'Not EJ' *before* the threshold lookup so that callers do not
        # have to supply a threshold that is never actually used.
        if pred["label"] != "Not EJ" and pred["score"] >= thresholds[pred["label"]]
    ]

    return classifications if classifications else ["Not EJ"]
@@ -63,14 +65,14 @@ def remove_unauthorized_classifications(classifications: list[str]) -> list[str]
63
65
def update_cmr_with_classifications (
64
66
inferences : list [dict [str , dict ]],
65
67
cmr_dict : dict [str , dict [str , dict ]],
66
- threshold : float = 0.5 ,
68
+ thresholds : dict [ str , float ] ,
67
69
) -> list [dict [str , dict ]]:
68
70
"""Update CMR data with valid classifications based on inferences."""
69
71
70
72
predicted_cmr = []
71
73
72
74
for inference in inferences :
73
- classifications = process_classifications (predictions = inference ["predictions" ], threshold = threshold )
75
+ classifications = process_classifications (predictions = inference ["predictions" ], thresholds = thresholds )
74
76
classifications = remove_unauthorized_classifications (classifications )
75
77
76
78
if classifications :
@@ -84,17 +86,30 @@ def update_cmr_with_classifications(
84
86
85
87
86
88
def main():
    """Classify the CMR collections and dump the result to a timestamped JSON file.

    Loads the model inferences and the CMR collections from hard-coded JSON
    files, applies the per-indicator thresholds defined below through
    ``update_cmr_with_classifications``, and writes the output to
    ``ej_dump_<YYYYmmdd_HHMMSS>.json`` (a relative path), printing the file name
    afterwards.
    """
    # Minimum score an indicator must reach to be kept as a classification.
    # NOTE(review): the "Not EJ" entry is kept so that every label has a
    # threshold available to the lookup in process_classifications.
    thresholds = {
        "Not EJ": 0.80,
        "Climate Change": 0.95,
        "Disasters": 0.80,
        "Extreme Heat": 0.50,
        "Food Availability": 0.80,
        "Health & Air Quality": 0.90,
        "Human Dimensions": 0.80,
        "Urban Flooding": 0.50,
        "Water Availability": 0.80,
    }

    inferences = load_json_file("alpha-1.3-wise-vortex-42-predictions.json")
    cmr = load_json_file("cmr_collections_umm_20240807_142146.json")

    cmr_dict = create_cmr_dict(cmr)

    predicted_cmr = update_cmr_with_classifications(inferences=inferences, cmr_dict=cmr_dict, thresholds=thresholds)

    # No whitespace inside the literal parts of the f-string: the dump is named
    # "ej_dump_<timestamp>.json", matching the scp example in the module docstring.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = f"ej_dump_{timestamp}.json"

    save_to_json(predicted_cmr, file_name)
    print(f"Saved to {file_name}")
98
113
99
114
100
115
if __name__ == "__main__" :
0 commit comments