22inferences are supplied by the classification model. the contact point is Bishwas
33cmr is supplied by running
44github.com/NASA-IMPACT/llm-app-EJ-classifier/blob/develop/scripts/data_processing/download_cmr.py
5- move to the serve like this: scp ej_dump_20240814_143036 .json sde:/home/ec2-user/sde_indexing_helper/backups/
5+ move to the server like this: scp ej_dump_20241017_133151 .json sde:/home/ec2-user/sde_indexing_helper/backups/
66"""
77
88import json
@@ -19,20 +19,22 @@ def save_to_json(data: dict | list, file_path: str) -> None:
1919 json .dump (data , file , indent = 2 )
2020
2121
22- def process_classifications (predictions : list [dict [str , float ]], threshold : float = 0.5 ) -> list [str ]:
22+ def process_classifications (predictions : list [dict [str , float ]], thresholds : dict [ str , float ] ) -> list [str ]:
2323 """
24- Process the predictions and classify as follows :
25- 1. If 'Not EJ' is the highest scoring prediction, return 'Not EJ' as the only classification
26- 2. Filter classifications based on the threshold , excluding 'Not EJ'
27- 3. Default to 'Not EJ' if no classifications meet the threshold
24+ Process the predictions and classify based on the individual thresholds per indicator :
25+ 1. If 'Not EJ' is the highest scoring prediction, return 'Not EJ' as the only classification.
26+ 2. Filter classifications based on their individual thresholds , excluding 'Not EJ'.
27+ 3. Default to 'Not EJ' if no classifications meet the threshold.
2828 """
2929 highest_prediction = max (predictions , key = lambda x : x ["score" ])
3030
3131 if highest_prediction ["label" ] == "Not EJ" :
3232 return ["Not EJ" ]
3333
3434 classifications = [
35- pred ["label" ] for pred in predictions if pred ["score" ] >= threshold and pred ["label" ] != "Not EJ"
35+ pred ["label" ]
36+ for pred in predictions
37+ if pred ["score" ] >= thresholds [pred ["label" ]] and pred ["label" ] != "Not EJ"
3638 ]
3739
3840 return classifications if classifications else ["Not EJ" ]
@@ -63,14 +65,14 @@ def remove_unauthorized_classifications(classifications: list[str]) -> list[str]
6365def update_cmr_with_classifications (
6466 inferences : list [dict [str , dict ]],
6567 cmr_dict : dict [str , dict [str , dict ]],
66- threshold : float = 0.5 ,
68+ thresholds : dict [ str , float ] ,
6769) -> list [dict [str , dict ]]:
6870 """Update CMR data with valid classifications based on inferences."""
6971
7072 predicted_cmr = []
7173
7274 for inference in inferences :
73- classifications = process_classifications (predictions = inference ["predictions" ], threshold = threshold )
75+ classifications = process_classifications (predictions = inference ["predictions" ], thresholds = thresholds )
7476 classifications = remove_unauthorized_classifications (classifications )
7577
7678 if classifications :
@@ -84,17 +86,30 @@ def update_cmr_with_classifications(
8486
8587
8688def main ():
87- inferences = load_json_file ("cmr-inference.json" )
89+ thresholds = {
90+ "Not EJ" : 0.80 ,
91+ "Climate Change" : 0.95 ,
92+ "Disasters" : 0.80 ,
93+ "Extreme Heat" : 0.50 ,
94+ "Food Availability" : 0.80 ,
95+ "Health & Air Quality" : 0.90 ,
96+ "Human Dimensions" : 0.80 ,
97+ "Urban Flooding" : 0.50 ,
98+ "Water Availability" : 0.80 ,
99+ }
100+
101+ inferences = load_json_file ("alpha-1.3-wise-vortex-42-predictions.json" )
88102 cmr = load_json_file ("cmr_collections_umm_20240807_142146.json" )
89103
90104 cmr_dict = create_cmr_dict (cmr )
91105
92- predicted_cmr = update_cmr_with_classifications (inferences = inferences , cmr_dict = cmr_dict , threshold = 0.8 )
106+ predicted_cmr = update_cmr_with_classifications (inferences = inferences , cmr_dict = cmr_dict , thresholds = thresholds )
93107
94108 timestamp = datetime .now ().strftime ("%Y%m%d_%H%M%S" )
95109 file_name = f"ej_dump_{ timestamp } .json"
96110
97111 save_to_json (predicted_cmr , file_name )
112+ print (f"Saved to { file_name } " )
98113
99114
100115if __name__ == "__main__" :
0 commit comments