Skip to content

Commit e700d77

Browse files
committed
Add geojson task and unit tests
1 parent a609176 commit e700d77

File tree

8 files changed

+413
-1
lines changed

8 files changed

+413
-1
lines changed

functions-python/helpers/locations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from sqlalchemy import func, cast
77
from geoalchemy2.types import Geography
88

9-
import pycountry
109
from shared.database_gen.sqlacodegen_models import Feed, Location, Geopolygon
1110
import logging
1211

@@ -35,6 +34,7 @@ def get_country_code(country_name: str) -> Optional[str]:
3534
Returns:
3635
Optional[str]: Two-letter ISO country code or None if not found
3736
"""
37+
import pycountry
3838
# Return None for empty or whitespace-only strings
3939
if not country_name or not country_name.strip():
4040
logging.error("Could not find country code for: empty string")

functions-python/tasks_executor/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pluggy~=1.3.0
1111
certifi~=2025.8.3
1212
fastapi
1313
uvicorn[standard]
14+
psutil
1415

1516

1617
# SQL Alchemy and Geo Alchemy

functions-python/tasks_executor/src/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
from tasks.missing_bounding_boxes.rebuild_missing_bounding_boxes import (
3232
rebuild_missing_bounding_boxes_handler,
3333
)
34+
from tasks.geojson.update_geojson_files_precision import (
35+
update_geojson_files_precision_handler,
36+
)
3437

3538
init_logger()
3639
LIST_COMMAND: Final[str] = "list"
@@ -62,6 +65,10 @@
6265
"description": "Rebuilds missing dataset files for GTFS datasets.",
6366
"handler": rebuild_missing_dataset_files_handler,
6467
},
68+
"update_geojson_files": {
69+
"description": "Iterate over bucket looking for {feed_stable_id}/geolocation.geojson and update precision.",
70+
"handler": update_geojson_files_precision_handler,
71+
},
6572
}
6673

6774

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
#
2+
# MobilityData 2025
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
17+
import json
18+
import logging
19+
import os
20+
import re
21+
from typing import Any, Dict, List
22+
23+
from shared.database.database import with_db_session
24+
from shared.database_gen.sqlacodegen_models import Gtfsfeed
25+
from shared.helpers.locations import round_geojson_coords
26+
from shared.helpers.runtime_metrics import track_metrics
27+
28+
GEOLOCATION_FILENAME = "geolocation.geojson"
29+
30+
31+
def _remove_osm_ids_from_properties(props: Dict[str, Any] | None):
32+
"""Remove OSM identifier-like properties from a feature's properties in-place."""
33+
if not props or not isinstance(props, dict):
34+
return
35+
keys_to_remove = []
36+
for k, v in list(props.items()):
37+
lk = k.lower()
38+
if "osm" in lk or lk == "@id":
39+
keys_to_remove.append(k)
40+
elif lk == "id":
41+
if isinstance(v, str) and re.search(
42+
r"\b(node|way|relation)\b|/", v, re.IGNORECASE
43+
):
44+
keys_to_remove.append(k)
45+
for k in keys_to_remove:
46+
props.pop(k, None)
47+
48+
49+
def query_unprocessed_feeds(limit, db_session):
    """Fetch up to ``limit`` Gtfsfeed rows not yet processed.

    "Unprocessed" means geolocation_file_created_date is NULL. A ``limit`` of
    None applies no limit.
    """
    query = (
        db_session.query(Gtfsfeed)
        .filter(Gtfsfeed.geolocation_file_created_date.is_(None))
        .limit(limit)
    )
    return query.all()
60+
61+
62+
@track_metrics(metrics=("time", "memory", "cpu"))
def _upload_file(bucket, geojson, blob_name: str = GEOLOCATION_FILENAME):
    """
    Serialize ``geojson`` and upload it to ``bucket`` as ``blob_name``.

    Args:
        bucket: GCS bucket client object.
        geojson: JSON-serializable GeoJSON document.
        blob_name: destination object name; defaults to the bucket-root
            "geolocation.geojson" (the original hard-coded behavior).

    NOTE(review): the default writes to the bucket root, NOT under a
    {feed_stable_id}/ prefix as the handler's read path uses — confirm the
    intended destination and pass an explicit per-feed blob name if the file
    should be updated in place.
    """
    processed_blob = bucket.blob(blob_name)
    processed_blob.upload_from_string(
        json.dumps(geojson, ensure_ascii=False),
        content_type="application/geo+json",
    )
69+
70+
71+
@track_metrics(metrics=("time", "memory", "cpu"))
def _update_feed_info(feed, timestamp):
    """Mark the feed as processed and link it to its newest bounded dataset.

    Sets geolocation_file_created_date, then points
    geolocation_file_dataset_id at the most recently downloaded dataset that
    has a bounding box (if any exists).
    """
    feed.geolocation_file_created_date = timestamp
    datasets_with_bbox = [d for d in (feed.gtfsdatasets or []) if d.bounding_box]
    if not datasets_with_bbox:
        return
    # Datasets with no downloaded_date sort using `timestamp` as a stand-in,
    # so they can win the "most recent" comparison — presumably intentional.
    latest = max(datasets_with_bbox, key=lambda d: d.downloaded_date or timestamp)
    # NOTE(review): reads .bounding_box.id — confirm bounding_box is a related
    # row exposing an id, not a raw geometry value.
    feed.geolocation_file_dataset_id = latest.bounding_box.id
81+
82+
83+
@track_metrics(metrics=("time", "memory", "cpu"))
def process_geojson(geopjson, precision):
    """
    Round coordinate precision and strip OSM ids from a GeoJSON document.

    Accepts a FeatureCollection dict, a single Feature dict, or a bare list of
    features. Features are mutated in place.

    Args:
        geopjson: the GeoJSON document (dict or list of features).
        precision: number of decimal places to keep in coordinates.

    Returns:
        The document in its original container shape, or None when the input
        structure is not recognized.
    """
    # Normalize the input into a flat list of features.
    if isinstance(geopjson, dict) and geopjson.get("type") == "FeatureCollection":
        features = geopjson.get("features", [])
    elif isinstance(geopjson, dict) and geopjson.get("type") == "Feature":
        features = [geopjson]
    elif isinstance(geopjson, list):
        features = geopjson
    else:
        # Unknown structure: nothing to process.
        return None

    # Apply rounding via the shared helper and remove OSM ids.
    for feature in features:
        if not isinstance(feature, dict):
            continue
        geometry = feature.get("geometry")
        if geometry:
            try:
                # round_geojson_coords returns a new geometry object.
                feature["geometry"] = round_geojson_coords(
                    geometry, precision=precision
                )
            except Exception as e:
                # Fix: skip only the failing feature. The original `return`
                # here aborted the whole document and made the caller treat
                # it as invalid GeoJSON.
                logging.warning("Error processing feature %s: %s", feature.get("name"), e)
                continue
        _remove_osm_ids_from_properties(feature.get("properties"))

    # Preserve the caller's original container shape.
    if isinstance(geopjson, dict) and geopjson.get("type") == "FeatureCollection":
        geopjson["features"] = features
        return geopjson
    if isinstance(geopjson, dict) and geopjson.get("type") == "Feature":
        return features[0] if features else geopjson
    return features
117+
118+
119+
@with_db_session
def update_geojson_files_precision_handler(
    payload: Dict[str, Any], db_session
) -> Dict[str, Any]:
    """
    Update per-feed GeoJSON files in GCS: reduce coordinate precision and
    remove OSM/map identifiers from feature properties.

    Payload keys:
        - bucket_name (str): GCS bucket; falls back to DATASETS_BUCKET_NAME env.
        - dry_run (bool): default True — process but do not upload or mark feeds.
        - precision (int): decimal places to keep, default 5.
        - limit (int | None): max feeds to process; omitted/None means no limit.

    Returns:
        Dict[str, Any]: summary with processed counts, errors and parameters.
    """
    # Import GCS client at runtime to avoid dev environment import issues
    try:
        from google.cloud import storage
    except Exception as e:
        raise RuntimeError("google-cloud-storage is required at runtime: %s" % e)

    bucket_name = payload.get("bucket_name") or os.getenv("DATASETS_BUCKET_NAME")
    if not bucket_name:
        # Fix: message named the wrong env var (GEOJSON_BUCKET vs the one read).
        raise ValueError(
            "bucket_name must be provided in payload or set in DATASETS_BUCKET_NAME env"
        )

    dry_run = payload.get("dry_run", True)
    precision = int(payload.get("precision", 5))
    # Fix: int(payload.get("limit", None)) raised TypeError when no limit was
    # supplied; None is passed through, meaning "no limit".
    raw_limit = payload.get("limit")
    limit = int(raw_limit) if raw_limit is not None else None

    client = storage.Client()
    bucket = client.bucket(bucket_name)

    # Fix: annotation said List[Dict[str, str]] but stable_id strings are appended.
    errors: List[str] = []
    processed = 0

    feeds: List[Gtfsfeed] = query_unprocessed_feeds(limit, db_session)
    logging.info("Found %s feeds", len(feeds))
    timestamp = db_session.execute("SELECT CURRENT_TIMESTAMP").scalar()
    for feed in feeds:
        try:
            # Periodic progress log + commit so long runs persist incrementally.
            if processed % 100 == 0:
                logging.info("Processed %s/%s", processed, len(feeds))
                db_session.commit()
            # Fix: storage.Blob is not a context manager; `with storage.Blob(...)`
            # raised AttributeError on __enter__. Use the blob object directly.
            blob = storage.Blob(
                bucket=bucket, name=f"{feed.stable_id}/{GEOLOCATION_FILENAME}"
            )
            if not blob.exists():
                logging.info("File does not exist: %s", blob.name)
                continue
            logging.info("Processing file: %s", blob.name)
            geojson = json.loads(blob.download_as_text())

            geojson = process_geojson(geojson, precision)
            if not geojson:
                # Fix: skip the feed; previously execution fell through and a
                # non-dry run uploaded "null" and stamped the feed as done.
                logging.info("No valid GeoJSON features found in %s", blob.name)
                continue

            # Optionally upload processed geojson
            if not dry_run:
                # NOTE(review): _upload_file writes to the bucket root, not under
                # {feed.stable_id}/ — confirm the intended destination.
                _upload_file(bucket, geojson)
                _update_feed_info(feed, timestamp)

            processed += 1
        except Exception as e:
            logging.exception("Error processing feed %s: %s", feed.stable_id, e)
            errors.append(feed.stable_id)

    summary = {
        "total_processed_files": processed,
        "errors": errors,
        "not_found_file": len(feeds) - processed - len(errors),
        "params": {
            "dry_run": dry_run,
            "precision": precision,
            "limit": limit,
        },
    }
    logging.info("update_geojson_files_precision_handler result: %s", summary)
    # Fix: the original built `summary` (and declared a Dict return type) but
    # ended with a bare `return`, handing None to the task executor.
    return summary

functions-python/tasks_executor/tests/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414
# limitations under the License.
1515
#
1616

17+
from pathlib import Path
import sys

# Make this function package's src/ importable (so `import shared` and
# friends resolve) before any test module is collected.
_package_root = Path(__file__).resolve().parents[1]
_src_path = str(_package_root / "src")
if _src_path not in sys.path:
    sys.path.insert(0, _src_path)
25+
1726
from datetime import datetime, UTC, timedelta
1827

1928
from sqlalchemy.orm import Session

0 commit comments

Comments
 (0)