Commit 3b6d4b2

implement gbfs geolocation adjustment

1 parent 800a462 commit 3b6d4b2

File tree: 6 files changed, +125 −20 lines changed

functions-python/tasks_executor/README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -56,6 +56,7 @@ To update the geolocation files precision:
     "task": "update_geojson_files_precision",
     "payload": {
       "dry_run": true,
+      "data_type": "gtfs",
       "precision": 5,
       "limit": 10
     }
```
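
For illustration only: the README documents the task body but not the transport. Assuming the tasks executor is invoked over HTTP (an assumption; only the JSON body above comes from the README), a request exercising the new `data_type` field might look like this sketch, with a placeholder URL:

```python
import json

import requests  # assumption: the executor is reachable over plain HTTP

# Placeholder endpoint; substitute the deployed tasks_executor function URL.
TASKS_EXECUTOR_URL = "https://REGION-PROJECT.cloudfunctions.net/tasks-executor"

body = {
    "task": "update_geojson_files_precision",
    "payload": {
        "dry_run": True,      # default: no uploads or DB writes
        "data_type": "gbfs",  # new field in this commit; "gtfs" is the default
        "precision": 5,
        "limit": 10,
    },
}

resp = requests.post(
    TASKS_EXECUTOR_URL,
    data=json.dumps(body),
    headers={"Content-Type": "application/json"},
)
print(resp.status_code, resp.text)
```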

functions-python/tasks_executor/src/tasks/geojson/README.md

Lines changed: 7 additions & 4 deletions

````diff
@@ -18,7 +18,8 @@ The function accepts the following payload:
 {
   "dry_run": true, // [optional] If true, do not upload or modify the database (default: true)
   "precision": 5, // [optional] Number of decimal places to keep in coordinates (default: 5)
-  "limit": 10 // [optional] Limit the number of feeds to process (default: no limit)
+  "limit": 10, // [optional] Limit the number of feeds to process (default: no limit)
+  "data_type": "gtfs" // [optional] Type of data to process, either "gtfs" or "gbfs" (default: "gtfs")
 }
 ```
 
@@ -27,6 +28,7 @@ The function accepts the following payload:
 ```json
 {
   "dry_run": true,
+  "data_type": "gtfs",
   "limit": 10
 }
 ```
@@ -42,9 +44,10 @@ Also updates the `geolocation_file_created_date` and `geolocation_file_dataset_i
 
 The function requires the following environment variables:
 
-| Variable               | Description                                                    |
-| ---------------------- | ---------------------------------------------------------------------------- |
-| `DATASETS_BUCKET_NAME` | The name of the GCS bucket used to store extracted GTFS files |
+| Variable                     | Description                                                              |
+|--------------------------------|-------------------------------------------------------------------------|
+| `DATASETS_BUCKET_NAME`       | The name of the GCS bucket used to store extracted GTFS files           |
+| `GBFS_SNAPSHOTS_BUCKET_NAME` | The name of the GCS bucket used to store extracted GBFS snapshots files |
 
 ---
 
````
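
The new table row pairs with the `data_type` payload field: the task resolves its source bucket per feed type. A minimal sketch of that selection, mirroring the handler change further down (the helper name `resolve_bucket_name` is illustrative, not part of the codebase):

```python
import os
from typing import Literal


def resolve_bucket_name(payload: dict, data_type: Literal["gtfs", "gbfs"]) -> str:
    # An explicit bucket_name in the payload always wins; otherwise fall back
    # to the environment variable matching the data type, as the handler does.
    bucket_name = payload.get("bucket_name") or (
        os.getenv("DATASETS_BUCKET_NAME")
        if data_type == "gtfs"
        else os.getenv("GBFS_SNAPSHOTS_BUCKET_NAME")
    )
    if not bucket_name:
        raise ValueError("bucket_name must be provided in payload or via env")
    return bucket_name
```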

functions-python/tasks_executor/src/tasks/geojson/update_geojson_files_precision.py

Lines changed: 29 additions & 14 deletions

```diff
@@ -18,12 +18,13 @@
 import logging
 import os
 import re
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Literal
 
 from sqlalchemy import select, func
+from sqlalchemy.orm import Session
 
 from shared.database.database import with_db_session
-from shared.database_gen.sqlacodegen_models import Gtfsfeed
+from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gbfsfeed
 from shared.helpers.locations import round_geojson_coords
 from shared.helpers.runtime_metrics import track_metrics
 
@@ -48,13 +49,16 @@ def _remove_osm_ids_from_properties(props: Dict[str, Any] | None):
         props.pop(k, None)
 
 
-def query_unprocessed_feeds(limit, db_session):
+def query_unprocessed_feeds(
+    limit: int, feed_type: Literal["gtfs", "gbfs"], db_session: Session
+) -> List[Gtfsfeed] | List[Gbfsfeed]:
     """
-    Query Gtfsfeed entries that have not been processed yet, geolocation_file_created_date is null.
+    Query feed entries that have not been processed yet, geolocation_file_created_date is null.
     """
+    model: Gtfsfeed | Gbfsfeed = Gtfsfeed if feed_type == "gtfs" else Gbfsfeed
     feeds = (
-        db_session.query(Gtfsfeed)
-        .filter(Gtfsfeed.geolocation_file_created_date.is_(None))
+        db_session.query(model)
+        .filter(model.geolocation_file_created_date.is_(None))
         .limit(limit)
         .all()
     )
@@ -72,8 +76,10 @@ def _upload_file(bucket, file_path, geojson):
 
 
 @track_metrics(metrics=("time", "memory", "cpu"))
-def _update_feed_info(feed: Gtfsfeed, timestamp):
+def _update_feed_info(feed: Gtfsfeed | Gbfsfeed, timestamp):
     feed.geolocation_file_created_date = timestamp
+    if isinstance(feed, Gbfsfeed):
+        return
     # find the most recent dataset with bounding box and set the id
     if feed.gtfsdatasets and any(d.bounding_box for d in feed.gtfsdatasets):
         latest_with_bbox = max(
@@ -141,29 +147,37 @@ def update_geojson_files_precision_handler(
         from google.cloud import storage
     except Exception as e:
         raise RuntimeError("google-cloud-storage is required at runtime: %s" % e)
-    bucket_name = payload.get("bucket_name") or os.getenv("DATASETS_BUCKET_NAME")
-    if not bucket_name:
-        raise ValueError(
-            "bucket_name must be provided in payload or set in GEOJSON_BUCKET env"
-        )
 
     dry_run = payload.get("dry_run", True)
+    data_type: Literal["gtfs", "gbfs"] = payload.get("data_type", "gtfs")
     precision = int(payload.get("precision", 5))
     limit = int(payload.get("limit", None))
+    bucket_name = payload.get("bucket_name") or (
+        os.getenv("DATASETS_BUCKET_NAME")
+        if data_type == "gtfs"
+        else os.getenv("GBFS_SNAPSHOTS_BUCKET_NAME")
+    )
+    if not bucket_name:
+        raise ValueError(
+            "bucket_name must be provided in payload or set in GEOJSON_BUCKET env"
+        )
     client = storage.Client()
    bucket = client.bucket(bucket_name)
 
     errors: List[Dict[str, str]] = []
     processed = 0
 
-    feeds: [Gtfsfeed] = query_unprocessed_feeds(limit, db_session)
+    feeds: List[Gtfsfeed] | List[Gbfsfeed] = query_unprocessed_feeds(
+        limit, data_type, db_session
+    )
     logging.info("Found %s feeds", len(feeds))
     timestamp = db_session.execute(select(func.current_timestamp())).scalar()
     for feed in feeds:
         try:
             if processed % 100 == 0:
                 logging.info("Processed %s/%s", processed, len(feeds))
-                db_session.commit()
+                if not dry_run and processed > 0:
+                    db_session.commit()
             file_path = f"{feed.stable_id}/{GEOLOCATION_FILENAME}"
             file = storage.Blob(bucket=bucket, name=file_path)
             if not file.exists():
@@ -189,6 +203,7 @@ def update_geojson_files_precision_handler(
         except Exception as e:
             logging.exception("Error processing feed %s: %s", feed.stable_id, e)
             errors.append(feed.stable_id)
+    logging.info("Processed %s/%s", processed, len(feeds))
     if not dry_run and processed > 0:
         db_session.commit()
     summary = {
```
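
To see the refactored query in isolation: a sketch of exercising it against a local session. The connection string is a placeholder; `query_unprocessed_feeds` and the model classes come from the module changed above.

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from tasks.geojson.update_geojson_files_precision import query_unprocessed_feeds

# Placeholder connection string for a local development database.
engine = create_engine("postgresql+psycopg2://user:pass@localhost:5432/feeds")

with Session(engine) as session:
    # With feed_type="gbfs" the query now targets Gbfsfeed rather than Gtfsfeed,
    # filtering on the geolocation_file_created_date column both models share.
    pending = query_unprocessed_feeds(10, "gbfs", session)
    for feed in pending:
        print(feed.stable_id)
```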

functions-python/tasks_executor/tests/conftest.py

Lines changed: 10 additions & 1 deletion

```diff
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import uuid
 from datetime import datetime, UTC, timedelta
 
 from sqlalchemy.orm import Session
@@ -22,6 +22,7 @@
 from shared.database_gen.sqlacodegen_models import (
     Gtfsfeed,
     Gtfsdataset,
+    Gbfsfeed,
 )
 from test_shared.test_utils.database_utils import clean_testing_db, default_db_url
 
@@ -47,6 +48,14 @@ def populate_database(db_session: Session | None = None):
         )
         db_session.add(feed)
         feeds.append(feed)
+    gbfs_feed = Gbfsfeed(
+        id=f"feed_{uuid.uuid4()}",
+        stable_id=f"stable_feed_gbfs_{uuid.uuid4()}",
+        data_type="gbfs",
+        status="active",
+        created_at=now,
+    )
+    db_session.add(gbfs_feed)
     db_session.flush()
 
     datasets = []
```

functions-python/tasks_executor/tests/tasks/geojson/test_update_geojson_files_precision.py

Lines changed: 77 additions & 1 deletion

```diff
@@ -8,6 +8,7 @@
 from sqlalchemy.orm import Session
 
 from shared.database.database import with_db_session
+from shared.database_gen.sqlacodegen_models import Gbfsfeed
 from shared.helpers.src.shared.database_gen.sqlacodegen_models import Gtfsfeed
 from tasks.geojson.update_geojson_files_precision import (
     process_geojson,
@@ -137,7 +138,7 @@ def test_process_geojson_single_feature_and_list_variants(self):
         )
 
     @with_db_session(db_url=default_db_url)
-    def test_handler_uploads_and_updates_feed_info(self, db_session: Session):
+    def test_handler_uploads_and_updates_gtfs_feed_info(self, db_session: Session):
         geo = {
             "type": "FeatureCollection",
             "features": [
@@ -209,6 +210,81 @@ def test_handler_uploads_and_updates_feed_info(self, db_session: Session):
         self.assertIsNotNone(reloaded_testing_feed.geolocation_file_dataset_id)
         self.assertIsNotNone(reloaded_testing_feed.geolocation_file_created_date)
 
+    @with_db_session(db_url=default_db_url)
+    def test_handler_uploads_and_updates_gbfs_feed_info(self, db_session: Session):
+        geo = {
+            "type": "FeatureCollection",
+            "features": [
+                {
+                    "type": "Feature",
+                    "geometry": {
+                        "type": "Point",
+                        "coordinates": [100.1234567, 0.9876543],
+                    },
+                    "properties": {"id": "node/1", "keep": "x"},
+                }
+            ],
+        }
+        testing_gbfs_feed = db_session.query(Gbfsfeed).limit(1).first()
+        self.assertIsNotNone(testing_gbfs_feed)
+        feed_stable_id = testing_gbfs_feed.stable_id
+        blob_name = f"{feed_stable_id}/{GEOLOCATION_FILENAME}"
+
+        fake_bucket = FakeBucket(initial_blobs={blob_name: json.dumps(geo)})
+        fake_storage = FakeStorageModule(fake_bucket, blob_exists=True)
+
+        # create module objects for google and google.cloud and inject via sys.modules
+        cloud_mod = types.ModuleType("google.cloud")
+        # 'from google.cloud import storage' in handler will bind 'storage' to this attribute
+        cloud_mod.storage = fake_storage
+        google_mod = types.ModuleType("google")
+        google_mod.cloud = cloud_mod
+
+        payload = {
+            "bucket_name": "any-bucket",
+            "dry_run": False,
+            "data_type": "gbfs",
+            "precision": 5,
+            "limit": 1,
+        }
+
+        # Inject modules into sys.modules for the duration of the handler call
+        with patch.dict(sys.modules, {"google.cloud": cloud_mod, "google": google_mod}):
+            # call wrapped handler to provide fake db_session
+            result = update_geojson_files_precision_handler(
+                payload, db_session=db_session
+            )
+
+        # verify upload happened
+        self.assertIn(blob_name, fake_bucket.uploaded)
+        uploaded_text = fake_bucket.uploaded[blob_name]
+        uploaded_geo = json.loads(uploaded_text)
+        coords = uploaded_geo.get("features")[0]["geometry"]["coordinates"]
+        self.assertEqual(coords, [round(100.1234567, 5), round(0.9876543, 5)])
+
+        self.assertEqual(
+            {
+                "total_processed_files": 1,
+                "errors": [],
+                "not_found_file": 0,
+                "params": {
+                    "dry_run": False,
+                    "precision": 5,
+                    "limit": 1,
+                },
+            },
+            result,
+        )
+        # feed updated
+        reloaded_testing_feed = (
+            db_session.query(Gbfsfeed)
+            .filter(Gbfsfeed.id.__eq__(testing_gbfs_feed.id))
+            .limit(1)
+            .first()
+        )
+        self.assertIsNone(reloaded_testing_feed.geolocation_file_dataset_id)
+        self.assertIsNotNone(reloaded_testing_feed.geolocation_file_created_date)
+
     @with_db_session(db_url=default_db_url)
     def test_handler_file_dont_exists(self, db_session: Session):
         fake_bucket = FakeBucket(initial_blobs={})
```
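
`FakeBucket` and `FakeStorageModule` are defined elsewhere in this test module and do not appear in the diff. Inferring from how the test uses them (`initial_blobs`, `uploaded`, `blob_exists`, and the handler's `storage.Client()` / `storage.Blob()` calls), a plausible minimal shape is the following sketch; the real definitions may differ:

```python
import types


class FakeBucket:
    """Assumed test double: blob text lives in dicts the test can seed and inspect."""

    def __init__(self, initial_blobs=None):
        self.blobs = dict(initial_blobs or {})  # name -> contents seeded by the test
        self.uploaded = {}  # name -> contents written by the handler under test


class FakeBlob:
    """Assumed blob double covering only the calls the handler makes."""

    def __init__(self, bucket, name, blob_exists):
        self.bucket, self.name, self._blob_exists = bucket, name, blob_exists

    def exists(self, client=None):
        return self._blob_exists and self.name in self.bucket.blobs

    def download_as_text(self):
        return self.bucket.blobs[self.name]

    def upload_from_string(self, data, content_type=None):
        self.bucket.uploaded[self.name] = data


def FakeStorageModule(fake_bucket, blob_exists=True):
    """Assumed factory mimicking the slice of google.cloud.storage the handler uses."""
    return types.SimpleNamespace(
        Client=lambda: types.SimpleNamespace(bucket=lambda name: fake_bucket),
        Blob=lambda bucket, name: FakeBlob(bucket, name, blob_exists),
    )
```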

infra/functions-python/main.tf

Lines changed: 1 addition & 0 deletions

```diff
@@ -1154,6 +1154,7 @@ resource "google_cloudfunctions2_function" "tasks_executor" {
       DATASET_PROCESSING_TOPIC_NAME = "datasets-batch-topic-${var.environment}"
       MATERIALIZED_VIEW_QUEUE       = google_cloud_tasks_queue.refresh_materialized_view_task_queue.name
       DATASETS_BUCKET_NAME          = "${var.datasets_bucket_name}-${var.environment}"
+      GBFS_SNAPSHOTS_BUCKET_NAME    = google_storage_bucket.gbfs_snapshots_bucket.name
     }
     available_memory = local.function_tasks_executor_config.memory
     timeout_seconds  = local.function_tasks_executor_config.timeout
```
