
Commit d1843a6

feat: update feed geolocation information (#1361)
1 parent: f5f51a9

9 files changed (+145 −62 lines)

api/src/shared/common/gcp_utils.py

Lines changed: 6 additions & 5 deletions
@@ -1,15 +1,15 @@
 import json
 import logging
 import os
-from google.cloud import tasks_v2
-from google.protobuf.timestamp_pb2 import Timestamp
 
 REFRESH_VIEW_TASK_EXECUTOR_BODY = json.dumps(
     {"task": "refresh_materialized_view", "payload": {"dry_run": False}}
 ).encode()
 
 
 def create_refresh_materialized_view_task():
+    from google.cloud import tasks_v2
+
     """
     Asynchronously refresh a materialized view.
     Ensures deduplication by generating a unique task name.
@@ -70,18 +70,19 @@ def create_refresh_materialized_view_task():
 
 
 def create_http_task_with_name(
-    client: "tasks_v2.CloudTasksClient",
+    client: any,  # tasks_v2.CloudTasksClient
     body: bytes,
     url: str,
     project_id: str,
     gcp_region: str,
     queue_name: str,
     task_name: str,
-    task_time: Timestamp,
-    http_method: "tasks_v2.HttpMethod",
+    task_time,
+    http_method: any,  # tasks_v2.HttpMethod
    timeout_s: int = 1800,  # 30 minutes
 ):
     """Creates a GCP Cloud Task."""
+    from google.cloud import tasks_v2
     from google.protobuf import duration_pb2
 
     token = tasks_v2.OidcToken(service_account_email=os.getenv("SERVICE_ACCOUNT_EMAIL"))
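Note: this hunk applies a deferred-import pattern. `tasks_v2` moves from module scope into the function bodies, so importing `gcp_utils` no longer loads (or requires) `google-cloud-tasks`; the trade-off is that the module-level annotations that referenced `tasks_v2` types are loosened to `any`. A minimal sketch of the pattern, with a hypothetical function name:

```python
# Minimal sketch of the deferred-import pattern used above (function name
# hypothetical). The client library is resolved only when the function runs,
# so modules importing this one never load google-cloud-tasks at import time.
def make_tasks_client():
    from google.cloud import tasks_v2  # deferred until first call

    return tasks_v2.CloudTasksClient()
```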

api/src/shared/database/database.py

Lines changed: 6 additions & 1 deletion
@@ -5,7 +5,7 @@
 import uuid
 from typing import Type, Callable
 from dotenv import load_dotenv
-from sqlalchemy import create_engine, text, event
+from sqlalchemy import create_engine, text, event, func, select
 from sqlalchemy.orm import load_only, Query, class_mapper, Session, mapper
 from shared.database_gen.sqlacodegen_models import (
     Base,
@@ -33,6 +33,11 @@ def generate_unique_id() -> str:
     return str(uuid.uuid4())
 
 
+def get_db_timestamp(db_session: Session) -> func.current_timestamp():
+    """Get the current time from the database."""
+    return db_session.execute(select(func.current_timestamp())).scalar()
+
+
 def configure_polymorphic_mappers():
     """
     Configure the polymorphic mappers allowing polymorphic values on relationships.
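`get_db_timestamp` asks the database for `CURRENT_TIMESTAMP` instead of reading the worker's local clock, so timestamps stay consistent across processes whose clocks may drift. A hedged usage sketch (the engine URL is a placeholder):

```python
# Usage sketch for get_db_timestamp; the connection URL is illustrative.
from sqlalchemy import create_engine, func, select
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")  # stand-in for the real database
with Session(engine) as session:
    # Compiles to: SELECT CURRENT_TIMESTAMP
    now = session.execute(select(func.current_timestamp())).scalar()
    print(now)  # on PostgreSQL this is a timezone-aware datetime
```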

functions-python/reverse_geolocation/src/reverse_geolocation_processor.py

Lines changed: 47 additions & 23 deletions
@@ -26,13 +26,14 @@
 )
 from parse_request import parse_request_parameters
 from shared.common.gcp_utils import create_refresh_materialized_view_task
-from shared.database.database import with_db_session
+from shared.database.database import with_db_session, get_db_timestamp
 from shared.database_gen.sqlacodegen_models import (
     Feed,
     Feedlocationgrouppoint,
     Osmlocationgroup,
     Gtfsdataset,
     Gtfsfeed,
+    Gbfsfeed,
 )
 from shared.dataset_service.dataset_service_commons import Status
 
@@ -134,15 +135,18 @@ def clean_stop_cache(db_session, feed, geometries_to_delete, logger):
     db_session.commit()
 
 
+@with_db_session
 def create_geojson_aggregate(
     location_groups: List[GeopolygonAggregate],
     total_stops: int,
-    stable_id: str,
     bounding_box: shapely.Polygon,
     data_type: str,
     logger,
+    feed: Gtfsfeed | Gbfsfeed,
+    gtfs_dataset: Gtfsdataset = None,
     extraction_urls: List[str] = None,
     public: bool = True,
+    db_session: Session = None,
 ) -> None:
     """Create a GeoJSON file with the aggregated locations. This file will be uploaded to GCS and used for
     visualization."""
@@ -197,10 +201,13 @@ def create_geojson_aggregate(
     else:
         raise ValueError("The data type must be either 'gtfs' or 'gbfs'.")
     bucket = storage_client.bucket(bucket_name)
-    blob = bucket.blob(f"{stable_id}/geolocation.geojson")
+    blob = bucket.blob(f"{feed.stable_id}/geolocation.geojson")
     blob.upload_from_string(json.dumps(json_data))
     if public:
         blob.make_public()
+    feed.geolocation_file_created_date = get_db_timestamp(db_session)
+    if gtfs_dataset:
+        feed.geolocation_file_dataset = gtfs_dataset
     logger.info("GeoJSON data saved to %s", blob.name)
 
 
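After the upload, the feed row is stamped with a database-clock creation date and, when a GTFS dataset is present, linked to the dataset the file was derived from. A condensed, hedged sketch of the sequence (the bucket name is hypothetical; `get_db_timestamp` is the helper added in this commit):

```python
# Hedged sketch of the upload-then-stamp sequence (bucket name hypothetical).
from google.cloud import storage
from shared.database.database import get_db_timestamp

def upload_geolocation_file(feed, geojson_str: str, db_session, public: bool = True):
    bucket = storage.Client().bucket("example-geolocation-bucket")  # placeholder
    blob = bucket.blob(f"{feed.stable_id}/geolocation.geojson")
    blob.upload_from_string(geojson_str)
    if public:
        blob.make_public()  # served directly for visualization
    # Provenance is recorded only after a successful upload.
    feed.geolocation_file_created_date = get_db_timestamp(db_session)
```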
@@ -210,10 +217,9 @@ def get_storage_client():
     return storage.Client()
 
 
-@with_db_session
 @track_metrics(metrics=("time", "memory", "cpu"))
 def update_dataset_bounding_box(
-    dataset_id: str, stops_df: pd.DataFrame, db_session: Session
+    gtfs_dataset: Gtfsdataset, stops_df: pd.DataFrame, db_session: Session
 ) -> shapely.Polygon:
     """
     Update the bounding box of the dataset using the stops DataFrame.
@@ -231,19 +237,12 @@ def update_dataset_bounding_box(
         f")",
         srid=4326,
     )
-    if not dataset_id:
-        return to_shape(bounding_box)
-    gtfs_dataset = (
-        db_session.query(Gtfsdataset)
-        .filter(Gtfsdataset.stable_id == dataset_id)
-        .one_or_none()
-    )
     if not gtfs_dataset:
-        raise ValueError(f"Dataset {dataset_id} does not exist in the database.")
+        return to_shape(bounding_box)
     gtfs_feed = db_session.get(Gtfsfeed, gtfs_dataset.feed_id)
     if not gtfs_feed:
         raise ValueError(
-            f"GTFS feed for dataset {dataset_id} does not exist in the database."
+            f"GTFS feed for dataset {gtfs_dataset.stable_id} does not exist in the database."
         )
     gtfs_feed.bounding_box = bounding_box
     gtfs_feed.bounding_box_dataset = gtfs_dataset
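The function now takes the ORM object rather than a stable ID: a `None` dataset is an expected case (no `dataset_id` in the request, presumably the GBFS path) and yields the computed bounding box without persisting anything, while the lookup-failure error moved upstream into `load_dataset`. Note also the primary-key lookup, as a fragment with names from the diff:

```python
# Session.get (SQLAlchemy 1.4+) fetches by primary key and returns None on a
# miss, which is why the code pairs it with an explicit ValueError.
gtfs_feed = db_session.get(Gtfsfeed, gtfs_dataset.feed_id)
if gtfs_feed is None:
    raise ValueError(f"GTFS feed for dataset {gtfs_dataset.stable_id} does not exist in the database.")
```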
@@ -252,8 +251,22 @@
     return to_shape(bounding_box)
 
 
+def load_dataset(dataset_id: str, db_session: Session) -> Gtfsdataset:
+    gtfs_dataset = (
+        db_session.query(Gtfsdataset)
+        .filter(Gtfsdataset.stable_id == dataset_id)
+        .one_or_none()
+    )
+    if not gtfs_dataset:
+        raise ValueError(
+            f"Dataset with ID {dataset_id} does not exist in the database."
+        )
+    return gtfs_dataset
+
+
+@with_db_session()
 def reverse_geolocation_process(
-    request: flask.Request,
+    request: flask.Request, db_session: Session = None
 ) -> Tuple[str, int] | Tuple[Dict, int]:
     """
     Main function to handle reverse geolocation processing.
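The entry point now owns a single `db_session` (injected by `with_db_session`) and threads it into every helper, instead of each helper opening its own session. The decorator's implementation isn't shown in this diff, and it is applied both bare and called (`@with_db_session` and `@with_db_session()`), so the real version likely supports both forms; a common shape for such a decorator looks like this (illustrative only):

```python
# Illustrative sketch of a with_db_session-style decorator; the repository's
# actual implementation may differ. It fills db_session when the caller omits
# it, while tests can still pass an explicit session.
import functools
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

engine = create_engine("sqlite:///:memory:")  # placeholder engine

def with_db_session(func):
    @functools.wraps(func)
    def wrapper(*args, db_session: Session = None, **kwargs):
        if db_session is not None:
            return func(*args, db_session=db_session, **kwargs)
        with Session(engine) as session:
            return func(*args, db_session=session, **kwargs)
    return wrapper
```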
@@ -331,14 +344,21 @@ def reverse_geolocation_process(
 
     try:
         # Update the bounding box of the dataset
-        bounding_box = update_dataset_bounding_box(dataset_id, stops_df)
+        gtfs_dataset: Gtfsdataset = None
+        if dataset_id:
+            gtfs_dataset = load_dataset(dataset_id, db_session)
+        feed = load_feed(stable_id, data_type, logger, db_session)
+
+        bounding_box = update_dataset_bounding_box(gtfs_dataset, stops_df, db_session)
 
         location_groups = reverse_geolocation(
             strategy=strategy,
             stable_id=stable_id,
             stops_df=stops_df,
+            data_type=data_type,
             logger=logger,
             use_cache=use_cache,
+            db_session=db_session,
         )
 
         if not location_groups:
@@ -358,14 +378,19 @@
         create_geojson_aggregate(
             list(location_groups.values()),
             total_stops=total_stops,
-            stable_id=stable_id,
             bounding_box=bounding_box,
             data_type=data_type,
             extraction_urls=extraction_urls,
             logger=logger,
             public=public,
+            feed=feed,
+            gtfs_dataset=gtfs_dataset,
+            db_session=db_session,
         )
 
+        # Commit the changes to the database
+        db_session.commit()
+        create_refresh_materialized_view_task()
         logger.info(
             "COMPLETED. Processed %s stops for stable ID %s with strategy. "
             "Retrieved %s locations.",
@@ -408,6 +433,7 @@ def reverse_geolocation(
     strategy,
     stable_id,
     stops_df,
+    data_type,
     logger,
     use_cache,
     db_session: Session = None,
@@ -417,7 +443,7 @@
     """
     logger.info("Processing geopolygons with strategy: %s.", strategy)
 
-    feed = load_feed(stable_id, logger, db_session)
+    feed = load_feed(stable_id, data_type, logger, db_session)
 
     # Get Geopolygons with Geometry and cached location groups
     cache_location_groups, unmatched_stops_df = get_geopolygons_with_geometry(
@@ -453,13 +479,13 @@
         logger=logger,
         db_session=db_session,
     )
-    create_refresh_materialized_view_task()
     return cache_location_groups
 
 
-def load_feed(stable_id, logger, db_session):
+def load_feed(stable_id, data_type, logger, db_session) -> Gtfsfeed | Gbfsfeed:
+    """Load feed from the database using the stable ID and data type."""
     feed = (
-        db_session.query(Feed)
+        db_session.query(Gbfsfeed if data_type == "gbfs" else Gtfsfeed)
         .options(joinedload(Feed.feedlocationgrouppoints))
         .filter(Feed.stable_id == stable_id)
         .one_or_none()
@@ -508,5 +534,3 @@ def update_feed_location(
         gtfs_rt_feed.locations = feed_locations
     if feed_locations:
         feed.locations = feed_locations
-    # Commit the changes to the database
-    db_session.commit()
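`load_feed` now queries the concrete subclass (`Gbfsfeed` or `Gtfsfeed`) selected by `data_type` instead of the `Feed` base, so the caller gets the right polymorphic type back. A compact sketch of the dispatch (model classes come from the diff's imports; the eager-loading option is omitted):

```python
# Sketch of the data_type dispatch in load_feed (simplified).
def load_feed_sketch(stable_id, data_type, db_session):
    model = Gbfsfeed if data_type == "gbfs" else Gtfsfeed  # names from the diff
    return (
        db_session.query(model)
        .filter(model.stable_id == stable_id)
        .one_or_none()
    )
```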
