
Commit e5386fa

Merge pull request #271 from allenai/favyen/20260121-forest-loss-peru

Forest loss driver: new round of Peru annotation

2 parents c2d47f2 + 9bcd79b

File tree

12 files changed: +426 −161 lines

data/forest_loss_driver/config_studio_annotation.json

Lines changed: 56 additions & 134 deletions
Large diffs are not rendered by default.

rslp/forest_loss_driver/scripts/add_area_to_studio_tasks.py

Lines changed: 8 additions & 6 deletions
```diff
@@ -20,7 +20,7 @@
 from rslearn.utils.geometry import STGeometry
 from rslearn.utils.get_utm_ups_crs import get_utm_ups_projection
 
-BASE_URL = "https://earth-system-studio.allen.ai/api/v1"
+BASE_URL = "https://olmoearth.allenai.org/api/v1"
 
 # Arbitrary user ID to save the annotation under.
 # This one is ES Studio User.
@@ -37,8 +37,10 @@
     # Get the annotation metadata field ID for the Area field.
     url = f"{BASE_URL}/projects/{project_id}"
     response = requests.get(url, headers=headers, timeout=10)
-    assert response.status_code == 200
-    project_data = response.json()
+    response.raise_for_status()
+    json_data = response.json()
+    assert len(json_data["records"]) == 1
+    project_data = json_data["records"][0]
     metadata_field_id = None
     for metadata_field in project_data["template"]["annotation_metadata_fields"]:
         if metadata_field["name"] != "Area":
@@ -50,13 +52,13 @@
     # Now iterate through tasks.
     url = f"{BASE_URL}/projects/{project_id}/tasks?limit=1000"
     response = requests.get(url, headers=headers, timeout=10)
-    assert response.status_code == 200
+    response.raise_for_status()
     item_list = response.json()["items"]
     for task in tqdm.tqdm(item_list):
         task_id = task["id"]
         url = f"{BASE_URL}/tasks/{task_id}/annotations"
         response = requests.get(url, headers=headers, timeout=10)
-        assert response.status_code == 200
+        response.raise_for_status()
         fc = response.json()
         if len(fc["features"]) != 1:
             continue
@@ -106,4 +108,4 @@
 
         url = f"{BASE_URL}/annotations/{annotation_id}"
         response = requests.put(url, json.dumps(post_data), headers=headers, timeout=10)
-        assert response.status_code == 200
+        response.raise_for_status()
```
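For context on this change: `raise_for_status()` raises `requests.exceptions.HTTPError` on 4xx/5xx responses, so failures report the status code and URL instead of a bare `AssertionError`. A minimal sketch against a hypothetical nonexistent project:

```python
import requests

# Hypothetical request against a project that does not exist, to show the failure mode.
response = requests.get(
    "https://olmoearth.allenai.org/api/v1/projects/nonexistent", timeout=10
)
try:
    response.raise_for_status()
except requests.exceptions.HTTPError as err:
    # The exception message includes the status code and the URL.
    print(err)
```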
rslp/forest_loss_driver/scripts/peru_20260112/README.md

Lines changed: 87 additions & 0 deletions
This project populates examples for a new phase of Peru annotation.
## Get Predictions

First we get predictions in Peru for a five-year period. `integrated_config.yaml`
contains the YAML config used for the integrated inference pipeline in
olmoearth_projects:

```
python -m olmoearth_projects.main projects.forest_loss_driver.deploy integrated_pipeline --config ../rslearn_projects/rslp/forest_loss_driver/scripts/peru_20260112/integrated_config.yaml
```
We only need to run it up to the step where it collects the events across the Studio
jobs, which produced this file:

```
/weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/peru_20260112/inference/dataset_20260109/events_from_studio_jobs.geojson
```
## Select Examples

Then we select examples for annotation:

```
python rslp/forest_loss_driver/scripts/peru_20260112/select_examples_for_annotation.py
```

This script will read the events from the file above and write out an rslearn dataset
here:

```
/weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/peru_20260112/rslearn_dataset_for_selected_events/
```
The rslearn dataset should first be created with the config file from
`data/forest_loss_driver/config_studio_annotation.json`.
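To sketch that setup step (assuming rslearn's convention that the dataset
configuration lives at `config.json` under the dataset root):

```python
import shutil
from pathlib import Path

# Create the dataset root and install the annotation config as its config.json.
dataset_root = Path(
    "/weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/"
    "peru_20260112/rslearn_dataset_for_selected_events"
)
dataset_root.mkdir(parents=True, exist_ok=True)
shutil.copy(
    "data/forest_loss_driver/config_studio_annotation.json",
    dataset_root / "config.json",
)
```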
The selection is done by randomly sampling 100 forest loss events that were predicted
as each of logging/burned/none/river/airstrip (500 total), and another 500 where the
maximum probability is <0.5 (indicating the model was not confident).
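In condensed form, the selection criteria look like this (a sketch only; the full
logic, including the minimum-spacing check between selected events, is in
`select_examples_for_annotation.py` included in this commit):

```python
import random

RARE_CATEGORIES = ["logging", "burned", "none", "river", "airstrip"]
PROB_THRESHOLD = 0.5


def select_events(predictions: list) -> list:
    """Sketch of the sampling criteria over the prediction features."""
    by_class: dict[str, list] = {category: [] for category in RARE_CATEGORIES}
    low_confidence: list = []
    for feat in predictions:
        category = feat.properties["new_label"]
        if category in RARE_CATEGORIES:
            by_class[category].append(feat)
        elif max(feat.properties["probs"]) < PROB_THRESHOLD:
            low_confidence.append(feat)

    selected: list = []
    # 100 per rare category (500 total), plus 500 low-confidence events.
    for candidates in by_class.values():
        selected.extend(random.sample(candidates, min(100, len(candidates))))
    selected.extend(random.sample(low_confidence, min(500, len(low_confidence))))
    return selected
```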
## Prepare and Materialize

Make sure to set the PLANET_API_KEY environment variable, since it is used in the
dataset config. Then:

```
rslearn dataset prepare --root /weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/peru_20260112/rslearn_dataset_for_selected_events/ --workers 128 --retry-max-attempts 10 --retry-backoff-seconds 5
rslearn dataset materialize --root /weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/peru_20260112/rslearn_dataset_for_selected_events/ --workers 128 --retry-max-attempts 10 --retry-backoff-seconds 5 --ignore-errors
```
## Additional Steps

Afterwards there are a few additional steps we need to do, because we forgot to
include them in the initial example selection script.
First, rename the tasks so they have the format `[#113] 2024-05-13 at -8.9846, -76.7046 prediction:burned`:

```
python rslp/forest_loss_driver/scripts/peru_20260112/rename_tasks.py
```

Then, add the label layer (forest loss polygon):

```
python rslp/forest_loss_driver/scripts/peru_20260112/add_label.py
```
## Sync to Studio

Copy to GCS:

```
gsutil -m rsync -r /weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/peru_20260112/rslearn_dataset_for_selected_events/ gs://ai2-rslearn-projects-data/datasets/forest_loss_driver/dataset_v1/peru_20260112/rslearn_dataset_for_selected_events/
```

Then make a request to have Studio import the dataset (the project must be created in
Studio first):

```
curl https://olmoearth.allenai.org/api/v1/datasets/ingest --request PUT --header 'Content-Type: application/json' --header "Authorization: Bearer $STUDIO_API_TOKEN" --data '{"dataset_path": "gs://ai2-rslearn-projects-data/datasets/forest_loss_driver/dataset_v1/peru_20260112/rslearn_dataset_for_selected_events/", "project_id": "60e16f40-dbe8-4932-af1b-3f762572530d", "layer_source_names": {}, "prediction_layer_names": []}'
```
After the project is populated, copy the annotation metadata fields from another
project (it should have a Confidence enum with High/Medium/Low and an Area number with
range 0-9999), and use `../add_area_to_studio_tasks.py` to set the area in hectares
for each polygon.
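A minimal sketch for checking that the copied fields are present, reusing the project
endpoint from `add_area_to_studio_tasks.py` (the project ID is the one from the ingest
request above):

```python
import os

import requests

BASE_URL = "https://olmoearth.allenai.org/api/v1"
PROJECT_ID = "60e16f40-dbe8-4932-af1b-3f762572530d"
headers = {"Authorization": f"Bearer {os.environ['STUDIO_API_TOKEN']}"}

url = f"{BASE_URL}/projects/{PROJECT_ID}"
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
project_data = response.json()["records"][0]
for field in project_data["template"]["annotation_metadata_fields"]:
    # Expect a Confidence enum (High/Medium/Low) and an Area number (0-9999).
    print(field["name"])
```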
On 2026-01-20 we sent the project to ACA, and they are now looking at it; once
annotation is completed, we will need to look into retraining the model.
rslp/forest_loss_driver/scripts/peru_20260112/add_label.py

Lines changed: 77 additions & 0 deletions
```python
"""Add the label polygon since we forgot to include it initially."""

import multiprocessing
from datetime import datetime, timedelta

import tqdm
from rasterio.crs import CRS
from rslearn.dataset import Dataset
from rslearn.utils.feature import Feature
from rslearn.utils.geometry import Projection
from rslearn.utils.grid_index import GridIndex
from rslearn.utils.vector_format import GeojsonCoordinateMode, GeojsonVectorFormat
from upath import UPath

PREDICTION_FNAME = "/weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/peru_20260112/inference/dataset_20260109/events_from_studio_jobs.geojson"
OUTPUT_DATASET_PATH = "/weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/peru_20260112/rslearn_dataset_for_selected_events/"
NUM_WORKERS = 128

# Web Mercator projection that all windows are in (about 9.554 m/pixel, i.e.
# zoom level 14).
PROJECTION = Projection(CRS.from_epsg(3857), 9.554628535647032, -9.554628535647032)


def reproject_feature(feat: Feature) -> Feature:
    """Helper function to re-project a feature to the WebMercator projection."""
    return Feature(feat.geometry.to_projection(PROJECTION), feat.properties)


if __name__ == "__main__":
    multiprocessing.set_start_method("forkserver")

    # Load features (predictions) and windows.
    features = GeojsonVectorFormat().decode_from_file(UPath(PREDICTION_FNAME))
    dataset = Dataset(UPath(OUTPUT_DATASET_PATH))
    windows = dataset.load_windows(show_progress=True, workers=NUM_WORKERS)

    # We need to find the feature that corresponds to each window so we can add it as
    # the label layer. So we create a grid index over the features. We use Web Mercator
    # for the grid index since the index needs everything in one projection.
    p = multiprocessing.Pool(NUM_WORKERS)
    reprojected_features = p.imap_unordered(reproject_feature, features)
    grid_index = GridIndex(size=100)
    for feat in tqdm.tqdm(
        reprojected_features, desc="Creating grid index", total=len(features)
    ):
        grid_index.insert(feat.geometry.shp.bounds, feat)
    p.close()

    # Now iterate over windows and find the closest feature.
    # We make sure that the dates line up.
    for window in tqdm.tqdm(windows, desc="Adding labels"):
        candidates: list[Feature] = grid_index.query(window.bounds)
        best_feat = None
        best_distance: float | None = None
        for candidate in candidates:
            candidate_point = candidate.geometry.to_projection(PROJECTION).shp.centroid
            distance = window.get_geometry().shp.centroid.distance(candidate_point)
            if best_distance is None or distance < best_distance:
                best_feat = candidate
                best_distance = distance

        # The rslearn windows were created using select_examples_for_annotation.py
        # based on the centroid of the GeoJSON features, so if there is a large
        # distance then it must mean we matched to the wrong feature.
        if best_feat is None or best_distance is None or best_distance > 10:
            raise ValueError(f"no spatially matching feature for window {window.name}")

        feat_datetime = datetime.fromisoformat(best_feat.properties["oe_start_time"])
        if abs(feat_datetime - window.time_range[0]) > timedelta(days=1):
            raise ValueError(f"no temporally matching feature for window {window.name}")

        layer_dir = window.get_layer_dir("label")
        # Reset the label so it is marked unlabeled.
        best_feat.properties["new_label"] = "unlabeled"
        GeojsonVectorFormat(coordinate_mode=GeojsonCoordinateMode.WGS84).encode_vector(
            layer_dir, [best_feat]
        )
        window.mark_layer_completed("label")
```
rslp/forest_loss_driver/scripts/peru_20260112/integrated_config.yaml

Lines changed: 17 additions & 0 deletions
```yaml
integrated_config:
  weka_base_dir: "/weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/peru_20260112/inference/"
  gcs_base_dir: "gs://ai2-rslearn-projects-data/forest_loss_driver/dataset_v1/peru_20260112/inference/"
  extract_alerts_args:
    gcs_tiff_filenames:
      - "070W_10S_060W_00N.tif"
      - "070W_20S_060W_10S.tif"
      - "080W_10S_070W_00N.tif"
      - "080W_20S_070W_10S.tif"
    out_fname: "placeholder"
    country_data_path: "/weka/dfive-default/rslearn-eai/artifacts/natural_earth_countries/20240830/ne_10m_admin_0_countries.shp"
    countries: ["PE"]
    # 1825 days = five years.
    days: 1825
    max_number_of_events: 200000
    asset_workers: 128
    make_tiles_workers: 128
    write_individual_events_workers: 128
```
rslp/forest_loss_driver/scripts/peru_20260112/rename_tasks.py

Lines changed: 28 additions & 0 deletions
```python
"""We initially named the tasks differently, so we rename them to a better format."""

import random
import shutil

import tqdm
from rslearn.dataset.dataset import Dataset, Window
from upath import UPath

DATASET_PATH = "/weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/peru_20260112/rslearn_dataset_for_selected_events/"


if __name__ == "__main__":
    ds_path = UPath(DATASET_PATH)
    dataset = Dataset(ds_path)
    windows = dataset.load_windows()
    random.shuffle(windows)
    for idx, window in enumerate(tqdm.tqdm(windows)):
        src_name = window.name
        # Window names from select_examples_for_annotation.py look like
        # "[#0]_-76.7046_-8.9846_predicted:burned".
        _, lon_str, lat_str, predicted = src_name.split("_")
        # Keep just the category from the "predicted:category" suffix.
        predicted_category = predicted.split(":", 1)[1]
        date_time_str = window.time_range[0].strftime("%Y-%m-%d")
        dst_name = f"[#{idx+1:04d}] {date_time_str} at {float(lat_str):.04f}, {float(lon_str):.04f} prediction:{predicted_category}"
        shutil.move(
            Window.get_window_root(ds_path, window.group, src_name),
            Window.get_window_root(ds_path, window.group, dst_name),
        )
        window.name = dst_name
        window.save()
```
rslp/forest_loss_driver/scripts/peru_20260112/select_examples_for_annotation.py

Lines changed: 130 additions & 0 deletions
```python
"""Select examples for this new Peru annotation.

Based on predictions in Peru over a five-year period:
- Select 100 for each of logging/burned/none/river/airstrip
- Select 500 from other categories where max(probs) < 0.5
"""

import random
from datetime import datetime

from rasterio.crs import CRS
from rslearn.const import WGS84_PROJECTION
from rslearn.dataset import Dataset, Window
from rslearn.utils.feature import Feature
from rslearn.utils.geometry import Projection
from rslearn.utils.grid_index import GridIndex
from rslearn.utils.vector_format import GeojsonVectorFormat
from upath import UPath

PREDICTION_FNAME = "/weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/peru_20260112/inference/dataset_20260109/events_from_studio_jobs.geojson"
OUTPUT_DATASET_PATH = "/weka/dfive-default/rslearn-eai/datasets/forest_loss_driver/dataset_v1/peru_20260112/rslearn_dataset_for_selected_events/"
TARGET_GROUP = "20260112_peru"
RARE_CATEGORIES = ["logging", "burned", "none", "river", "airstrip"]
PROB_THRESHOLD = 0.5
# About 1000 m expressed in degrees of latitude.
DISTANCE_THRESHOLD = 1000 / 111111
WINDOW_SIZE = 128


if __name__ == "__main__":
    # Load predictions.
    predictions = GeojsonVectorFormat().decode_from_file(UPath(PREDICTION_FNAME))

    # Create candidates for the different selection criteria.
    by_class_options: dict[str, list[Feature]] = {
        category: [] for category in RARE_CATEGORIES
    }
    by_prob_options: list[Feature] = []
    for feat in predictions:
        category = feat.properties["new_label"]
        if category in RARE_CATEGORIES:
            by_class_options[category].append(feat)
        elif max(feat.properties["probs"]) < PROB_THRESHOLD:
            by_prob_options.append(feat)

    for category, candidates in by_class_options.items():
        print(f"got {len(candidates)} options by class for category={category}")
    print(f"got {len(by_prob_options)} options by prob")

    # Select windows; we make sure their center points are at least ~1 km away
    # from each other (see DISTANCE_THRESHOLD above).
    grid_index = GridIndex(size=DISTANCE_THRESHOLD)
    selected: list[Feature] = []

    def contains_bbox(box: tuple[float, float, float, float]) -> bool:
        """Check whether the box contains a point in grid_index."""
        for other in grid_index.query(box):
            if (
                other[0] > box[0]
                and other[1] > box[1]
                and other[0] < box[2]
                and other[1] < box[3]
            ):
                return True
        return False

    def add_random_sample_of_features(features: list[Feature], max_count: int) -> int:
        """Add a random sample of windows from the list to the selected set."""
        # Add up to max_count from the features list.
        random.shuffle(features)
        cur_selected: list[Feature] = []
        for feat in features:
            center_point = feat.geometry.to_projection(WGS84_PROJECTION).shp.centroid
            if contains_bbox(
                (
                    center_point.x - DISTANCE_THRESHOLD,
                    center_point.y - DISTANCE_THRESHOLD,
                    center_point.x + DISTANCE_THRESHOLD,
                    center_point.y + DISTANCE_THRESHOLD,
                )
            ):
                continue

            cur_selected.append(feat)
            grid_index.insert(
                (center_point.x, center_point.y, center_point.x, center_point.y),
                (center_point.x, center_point.y),
            )
            if len(cur_selected) >= max_count:
                break

        selected.extend(cur_selected)
        return len(cur_selected)

    for category, candidates in by_class_options.items():
        count = add_random_sample_of_features(candidates, 100)
        print(f"by class category={category} picked {count}/{len(candidates)} windows")
    count = add_random_sample_of_features(by_prob_options, 500)
    print(f"by prob picked {count}/{len(by_prob_options)} windows")
    print(f"got {len(selected)} total to remap")

    # Create windows in the destination dataset for these features.
    dataset = Dataset(UPath(OUTPUT_DATASET_PATH))
    # Web Mercator at about 9.554 m/pixel (zoom level 14).
    dst_proj = Projection(CRS.from_epsg(3857), 9.554628535647032, -9.554628535647032)
    random.shuffle(selected)
    for idx, feat in enumerate(selected):
        wgs84_geom = feat.geometry.to_projection(WGS84_PROJECTION)
        lon = wgs84_geom.shp.centroid.x
        lat = wgs84_geom.shp.centroid.y
        predicted_category = feat.properties["new_label"]
        window_name = f"[#{idx}]_{lon:.04f}_{lat:.04f}_predicted:{predicted_category}"

        # Get bounds in our WebMercator projection.
        dst_geom = feat.geometry.to_projection(dst_proj)
        dst_bounds = (
            int(dst_geom.shp.centroid.x) - WINDOW_SIZE // 2,
            int(dst_geom.shp.centroid.y) - WINDOW_SIZE // 2,
            int(dst_geom.shp.centroid.x) + WINDOW_SIZE // 2,
            int(dst_geom.shp.centroid.y) + WINDOW_SIZE // 2,
        )

        ts = datetime.fromisoformat(feat.properties["oe_start_time"])
        window = Window(
            storage=dataset.storage,
            group=TARGET_GROUP,
            name=window_name,
            projection=dst_proj,
            bounds=dst_bounds,
            time_range=(ts, ts),
        )
        window.save()
```
