
Commit 1df2a24

Merge pull request #936 from NASA-IMPACT/923-ej-classification-processing-and-ingest-scripts
923 ej classification processing and ingest scripts
2 parents c444ae5 + 9a62d3d

File tree

3 files changed: +208 -0 lines changed

scripts/ej/cmr_to_models.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
"""
The ej_dump is generated by running create_ej_dump.py and is scp'd to the COSMOS server.
This script is then run via the dm shell on the COSMOS server to populate the database.
"""

import json
import urllib.parse

from environmental_justice.models import EnvironmentalJusticeRow


def generate_source_link(doi_field):
    """Build a DOI link from a CMR DOI field, or return an empty string."""
    authority = doi_field.get("Authority")
    doi = doi_field.get("DOI")
    if authority and doi:
        return urllib.parse.urljoin(authority, doi)
    return ""


def concept_id_to_sinequa_id(concept_id: str) -> str:
    return f"/SDE/CMR_API/|{concept_id}"


def sinequa_id_to_url(sinequa_id: str) -> str:
    base_url = "https://sciencediscoveryengine.nasa.gov/app/nasa-sba-smd/#/preview"
    query = '{"name":"query-smd-primary","scope":"All","text":""}'

    encoded_id = urllib.parse.quote(sinequa_id, safe="")
    encoded_query = urllib.parse.quote(query, safe="")

    return f"{base_url}?id={encoded_id}&query={encoded_query}"


def categorize_processing_level(level):
    """Map a CMR processing level to an intended-use category."""
    advanced_analysis_levels = {"0", "Level 0", "NA", "Not Provided", "Not provided"}

    basic_analysis_levels = {
        "1",
        "1A",
        "1B",
        "1C",
        "1T",
        "2",
        "2A",
        "2B",
        "2G",
        "2P",
        "Level 1",
        "Level 1A",
        "Level 1B",
        "Level 1C",
        "Level 2",
        "Level 2A",
        "Level 2B",
    }

    exploration_levels = {"3", "4", "Level 3", "Level 4", "L2"}

    if level in exploration_levels:
        return "exploration"
    elif level in basic_analysis_levels:
        return "basic analysis"
    elif level in advanced_analysis_levels:
        return "advanced analysis"
    else:
        # unrecognized levels fall through to the most conservative category
        return "advanced analysis"


# remove existing data
EnvironmentalJusticeRow.objects.filter(destination_server=EnvironmentalJusticeRow.DestinationServerChoices.DEV).delete()

with open("backups/ej_dump_20240815_112916.json") as file:
    ej_dump = json.load(file)

for dataset in ej_dump:
    ej_row = EnvironmentalJusticeRow(
        destination_server=EnvironmentalJusticeRow.DestinationServerChoices.DEV,
        sde_link=sinequa_id_to_url(concept_id_to_sinequa_id(dataset.get("meta", {}).get("concept-id", ""))),
        dataset=dataset.get("umm", {}).get("ShortName", ""),
        description=dataset.get("umm", {}).get("Abstract", ""),
        limitations=dataset.get("umm", {}).get("AccessConstraints", {}).get("Description", ""),
        format=dataset.get("meta", {}).get("format", ""),
        temporal_extent=", ".join(dataset.get("umm", {}).get("TemporalExtents", [{}])[0].get("SingleDateTimes", [])),
        intended_use=categorize_processing_level(
            dataset.get("umm", {}).get("ProcessingLevel", {}).get("Id", "advanced analysis")
        ),
        source_link=generate_source_link(dataset.get("umm", {}).get("DOI", {})),
        indicators=dataset["indicators"],
        geographic_coverage="",  # Not provided in the data
        data_visualization="",  # dataset.get("umm", {}).get("RelatedUrls", [{}])[0].get("URL", ""),
        latency="",  # Not provided in the data
        spatial_resolution="",  # Not provided in the data
        temporal_resolution="",  # Not provided in the data
        description_simplified="",  # Not provided in the data
        project="",  # Not provided in the data
        strengths="",  # Not provided in the data
    )
    ej_row.save()
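
For intuition about the sde_link field: concept_id_to_sinequa_id prefixes the CMR concept id with the Sinequa source path, and sinequa_id_to_url percent-encodes both that id and a fixed query payload. A quick sketch with a made-up concept id (illustrative only, not a real CMR record):

    >>> concept_id_to_sinequa_id("C1234567-PROVIDER")
    '/SDE/CMR_API/|C1234567-PROVIDER'
    >>> sinequa_id_to_url("/SDE/CMR_API/|C1234567-PROVIDER")
    'https://sciencediscoveryengine.nasa.gov/app/nasa-sba-smd/#/preview?id=%2FSDE%2FCMR_API%2F%7CC1234567-PROVIDER&query=%7B%22name%22%3A%22query-smd-primary%22%2C%22scope%22%3A%22All%22%2C%22text%22%3A%22%22%7D'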

scripts/ej/create_ej_dump.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
"""
Inferences are supplied by the classification model; the contact point is Bishwas.
CMR data is supplied by running https://github.com/NASA-IMPACT/llm-app-EJ-classifier/blob/develop/scripts/data_processing/download_cmr.py
Move the dump to the server like this: scp ej_dump_20240814_143036.json sde:/home/ec2-user/sde_indexing_helper/backups/
"""

import json
from datetime import datetime


def load_json_file(file_path: str) -> dict:
    with open(file_path, "r") as file:
        return json.load(file)


def save_to_json(data: dict | list, file_path: str) -> None:
    with open(file_path, "w") as file:
        json.dump(data, file, indent=2)


def process_classifications(predictions: list[dict[str, float]], threshold: float = 0.5) -> list[str]:
    """
    Process the predictions and classify as follows:
    1. If 'Not EJ' is the highest scoring prediction, return 'Not EJ' as the only classification
    2. Filter classifications based on the threshold, excluding 'Not EJ'
    3. Default to 'Not EJ' if no classifications meet the threshold
    """
    highest_prediction = max(predictions, key=lambda x: x["score"])

    if highest_prediction["label"] == "Not EJ":
        return ["Not EJ"]

    classifications = [
        pred["label"] for pred in predictions if pred["score"] >= threshold and pred["label"] != "Not EJ"
    ]

    return classifications if classifications else ["Not EJ"]


def create_cmr_dict(cmr_data: list[dict[str, dict[str, str]]]) -> dict[str, dict[str, dict[str, str]]]:
    """Restructure CMR data into a dictionary with 'concept-id' as the key."""
    return {dataset["meta"]["concept-id"]: dataset for dataset in cmr_data}


def remove_unauthorized_classifications(classifications: list[str]) -> list[str]:
    """Filter classifications to keep only those in the authorized list."""
    authorized_classifications = [
        "Climate Change",
        "Disasters",
        "Extreme Heat",
        "Food Availability",
        "Health & Air Quality",
        "Human Dimensions",
        "Urban Flooding",
        "Water Availability",
    ]

    return [cls for cls in classifications if cls in authorized_classifications]


def update_cmr_with_classifications(
    inferences: list[dict[str, dict]],
    cmr_dict: dict[str, dict[str, dict]],
    threshold: float = 0.5,
) -> list[dict[str, dict]]:
    """Update CMR data with valid classifications based on inferences."""
    predicted_cmr = []

    for inference in inferences:
        classifications = process_classifications(predictions=inference["predictions"], threshold=threshold)
        classifications = remove_unauthorized_classifications(classifications)

        if classifications:
            cmr_dataset = cmr_dict.get(inference["concept-id"])

            if cmr_dataset:
                cmr_dataset["indicators"] = ";".join(classifications)
                predicted_cmr.append(cmr_dataset)

    return predicted_cmr


def main():
    inferences = load_json_file("cmr-inference.json")
    cmr = load_json_file("cmr_collections_umm_20240807_142146.json")

    cmr_dict = create_cmr_dict(cmr)

    predicted_cmr = update_cmr_with_classifications(inferences=inferences, cmr_dict=cmr_dict, threshold=0.8)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_name = f"ej_dump_{timestamp}.json"

    save_to_json(predicted_cmr, file_name)


if __name__ == "__main__":
    main()
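
To make the threshold logic concrete, a minimal sketch of the prediction records process_classifications expects (the labels and scores below are illustrative, not real model output):

    predictions = [
        {"label": "Climate Change", "score": 0.91},
        {"label": "Urban Flooding", "score": 0.62},
        {"label": "Not EJ", "score": 0.05},
    ]

    process_classifications(predictions, threshold=0.8)  # -> ["Climate Change"]
    process_classifications(predictions, threshold=0.5)  # -> ["Climate Change", "Urban Flooding"]

If "Not EJ" had the highest score, the function would short-circuit to ["Not EJ"], which remove_unauthorized_classifications then filters out, so that dataset would be dropped from the dump.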

scripts/ej/thresholding.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
from create_ej_dump import load_json_file, create_cmr_dict, update_cmr_with_classifications

inferences = load_json_file("cmr-inference.json")
cmr = load_json_file("cmr_collections_umm_20240807_142146.json")

cmr_dict = create_cmr_dict(cmr)

for threshold in [0.5, 0.6, 0.7, 0.8, 0.9]:
    predicted_cmr = update_cmr_with_classifications(inferences=inferences, cmr_dict=cmr_dict, threshold=threshold)
    print(f"Threshold: {int(threshold*100)}%, EJ datasets: {len(predicted_cmr)}")
