Commit 1177beb

update feed geolocation information

1 parent f5f51a9 commit 1177beb

4 files changed: +128 -52 lines changed

api/src/shared/common/gcp_utils.py
Lines changed: 6 additions & 5 deletions

```diff
@@ -1,15 +1,15 @@
 import json
 import logging
 import os
-from google.cloud import tasks_v2
-from google.protobuf.timestamp_pb2 import Timestamp
 
 REFRESH_VIEW_TASK_EXECUTOR_BODY = json.dumps(
     {"task": "refresh_materialized_view", "payload": {"dry_run": False}}
 ).encode()
 
 
 def create_refresh_materialized_view_task():
+    from google.cloud import tasks_v2
+
     """
     Asynchronously refresh a materialized view.
     Ensures deduplication by generating a unique task name.
@@ -70,18 +70,19 @@ def create_refresh_materialized_view_task():
 
 
 def create_http_task_with_name(
-    client: "tasks_v2.CloudTasksClient",
+    client: any,  # tasks_v2.CloudTasksClient
     body: bytes,
     url: str,
     project_id: str,
     gcp_region: str,
     queue_name: str,
     task_name: str,
-    task_time: Timestamp,
-    http_method: "tasks_v2.HttpMethod",
+    task_time,
+    http_method: any,  # tasks_v2.HttpMethod
     timeout_s: int = 1800,  # 30 minutes
 ):
     """Creates a GCP Cloud Task."""
+    from google.cloud import tasks_v2
     from google.protobuf import duration_pb2
 
     token = tasks_v2.OidcToken(service_account_email=os.getenv("SERVICE_ACCOUNT_EMAIL"))
```
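
This change moves the `google.cloud` imports from module scope into the function bodies and loosens the related type hints (`any` here is the builtin function; `typing.Any` would be the conventional spelling). Deferring the import is a common way to keep a heavy client library off the cold-start path for code that may never enqueue a task; note that inside `create_refresh_materialized_view_task` the import now sits above the docstring, demoting it to a plain string expression. A minimal sketch of the deferred-import pattern, with hypothetical project, queue, and endpoint values:

```python
# Sketch of the deferred-import pattern, assuming google-cloud-tasks is
# installed. All identifiers except tasks_v2 are hypothetical placeholders.
def enqueue_example_task():
    """Enqueue a POST task; the client library is only imported on first call."""
    from google.cloud import tasks_v2  # deferred: keeps module import cheap

    client = tasks_v2.CloudTasksClient()
    parent = client.queue_path("my-project", "us-central1", "my-queue")
    task = tasks_v2.Task(
        http_request=tasks_v2.HttpRequest(
            http_method=tasks_v2.HttpMethod.POST,
            url="https://example.com/task-handler",
            body=b'{"task": "example"}',
        )
    )
    return client.create_task(request={"parent": parent, "task": task})
```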

api/src/shared/database/database.py
Lines changed: 6 additions & 1 deletion

```diff
@@ -5,7 +5,7 @@
 import uuid
 from typing import Type, Callable
 from dotenv import load_dotenv
-from sqlalchemy import create_engine, text, event
+from sqlalchemy import create_engine, text, event, func, select
 from sqlalchemy.orm import load_only, Query, class_mapper, Session, mapper
 from shared.database_gen.sqlacodegen_models import (
     Base,
@@ -33,6 +33,11 @@ def generate_unique_id() -> str:
     return str(uuid.uuid4())
 
 
+def get_db_timestamp(db_session: Session) -> func.current_timestamp():
+    """Get the current time from the database."""
+    return db_session.execute(select(func.current_timestamp())).scalar()
+
+
 def configure_polymorphic_mappers():
     """
     Configure the polymorphic mappers allowing polymorphic values on relationships.
```
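
`get_db_timestamp` reads the clock from the database server rather than the application host, so timestamps stamped by different workers stay consistent. The return annotation is a quirk: `func.current_timestamp()` evaluates to a SQL function clause, not a type; the value `.scalar()` actually returns is a `datetime` (timezone-aware on PostgreSQL). A hedged usage sketch, with a hypothetical connection string:

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from shared.database.database import get_db_timestamp

# Hypothetical DSN; any backend that supports CURRENT_TIMESTAMP works.
engine = create_engine("postgresql+psycopg2://user:pass@localhost:5432/feeds")

with Session(engine) as session:
    db_now = get_db_timestamp(session)  # datetime from the DB server's clock
    print(db_now.isoformat())
```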

functions-python/reverse_geolocation/src/reverse_geolocation_processor.py
Lines changed: 55 additions & 19 deletions

```diff
@@ -26,7 +26,7 @@
 )
 from parse_request import parse_request_parameters
 from shared.common.gcp_utils import create_refresh_materialized_view_task
-from shared.database.database import with_db_session
+from shared.database.database import with_db_session, get_db_timestamp
 from shared.database_gen.sqlacodegen_models import (
     Feed,
     Feedlocationgrouppoint,
@@ -134,15 +134,18 @@ def clean_stop_cache(db_session, feed, geometries_to_delete, logger):
     db_session.commit()
 
 
+@with_db_session
 def create_geojson_aggregate(
     location_groups: List[GeopolygonAggregate],
     total_stops: int,
-    stable_id: str,
     bounding_box: shapely.Polygon,
     data_type: str,
     logger,
+    feed: Feed,
+    gtfs_dataset: Gtfsdataset = None,
     extraction_urls: List[str] = None,
     public: bool = True,
+    db_session: Session = None,
 ) -> None:
     """Create a GeoJSON file with the aggregated locations. This file will be uploaded to GCS and used for
     visualization."""
@@ -197,10 +200,13 @@ def create_geojson_aggregate(
     else:
         raise ValueError("The data type must be either 'gtfs' or 'gbfs'.")
     bucket = storage_client.bucket(bucket_name)
-    blob = bucket.blob(f"{stable_id}/geolocation.geojson")
+    blob = bucket.blob(f"{feed.stable_id}/geolocation.geojson")
     blob.upload_from_string(json.dumps(json_data))
     if public:
         blob.make_public()
+    feed.geolocation_file_created_date = get_db_timestamp(db_session)
+    if gtfs_dataset:
+        feed.geolocation_file_dataset = gtfs_dataset
     logger.info("GeoJSON data saved to %s", blob.name)
 
 
```
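With the `Feed` ORM object passed in instead of a bare `stable_id`, the function can stamp `geolocation_file_created_date` (using the database clock via `get_db_timestamp`) and link the originating dataset on the same session, leaving both writes pending for the caller's single commit. On the read side, a blob uploaded with `public=True` is reachable over plain HTTPS; a hedged consumer sketch with a hypothetical bucket and stable ID:

```python
import json
import urllib.request

# Both values are hypothetical; the object path mirrors the diff above and the
# blob must have been uploaded with public=True for this URL to resolve.
BUCKET = "mobility-feeds-geolocation"
STABLE_ID = "mdb-123"

url = f"https://storage.googleapis.com/{BUCKET}/{STABLE_ID}/geolocation.geojson"
with urllib.request.urlopen(url) as resp:
    geojson = json.load(resp)
print(geojson["type"], len(geojson.get("features", [])))
```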
```diff
@@ -210,10 +216,9 @@ def get_storage_client():
     return storage.Client()
 
 
-@with_db_session
 @track_metrics(metrics=("time", "memory", "cpu"))
 def update_dataset_bounding_box(
-    dataset_id: str, stops_df: pd.DataFrame, db_session: Session
+    gtfs_dataset: Gtfsdataset, stops_df: pd.DataFrame, db_session: Session
 ) -> shapely.Polygon:
     """
     Update the bounding box of the dataset using the stops DataFrame.
@@ -231,19 +236,12 @@ def update_dataset_bounding_box(
         f")",
         srid=4326,
     )
-    if not dataset_id:
-        return to_shape(bounding_box)
-    gtfs_dataset = (
-        db_session.query(Gtfsdataset)
-        .filter(Gtfsdataset.stable_id == dataset_id)
-        .one_or_none()
-    )
     if not gtfs_dataset:
-        raise ValueError(f"Dataset {dataset_id} does not exist in the database.")
+        return to_shape(bounding_box)
     gtfs_feed = db_session.get(Gtfsfeed, gtfs_dataset.feed_id)
     if not gtfs_feed:
         raise ValueError(
-            f"GTFS feed for dataset {dataset_id} does not exist in the database."
+            f"GTFS feed for dataset {gtfs_dataset.stable_id} does not exist in the database."
         )
     gtfs_feed.bounding_box = bounding_box
     gtfs_feed.bounding_box_dataset = gtfs_dataset
@@ -252,8 +250,22 @@ def update_dataset_bounding_box(
     return to_shape(bounding_box)
 
 
+def load_dataset(dataset_id: str, db_session: Session) -> Gtfsdataset:
+    gtfs_dataset = (
+        db_session.query(Gtfsdataset)
+        .filter(Gtfsdataset.stable_id == dataset_id)
+        .one_or_none()
+    )
+    if not gtfs_dataset:
+        raise ValueError(
+            f"Dataset with ID {dataset_id} does not exist in the database."
+        )
+    return gtfs_dataset
+
+
+@with_db_session()
 def reverse_geolocation_process(
-    request: flask.Request,
+    request: flask.Request, db_session: Session = None
 ) -> Tuple[str, int] | Tuple[Dict, int]:
     """
     Main function to handle reverse geolocation processing.
```
Main function to handle reverse geolocation processing.
```diff
@@ -331,14 +343,35 @@ def reverse_geolocation_process(
 
     try:
         # Update the bounding box of the dataset
-        bounding_box = update_dataset_bounding_box(dataset_id, stops_df)
+        if dataset_id:
+            gtfs_dataset = load_dataset(dataset_id, db_session)
+            feed = gtfs_dataset.feed
+        if not feed:
+            feed = (
+                db_session.query(Feed).filter(Feed.stable_id == stable_id).one_or_none()
+            )
+        if not feed:
+            no_feed_message = f"No feed found for stable ID {stable_id}."
+            logger.warning(no_feed_message)
+            record_execution_trace(
+                execution_id=execution_id,
+                stable_id=stable_id,
+                status=Status.FAILED,
+                logger=logger,
+                dataset_file=None,
+                error_message=no_feed_message,
+            )
+            return no_feed_message, ERROR_STATUS_CODE
+
+        bounding_box = update_dataset_bounding_box(gtfs_dataset, stops_df, db_session)
 
         location_groups = reverse_geolocation(
             strategy=strategy,
             stable_id=stable_id,
             stops_df=stops_df,
             logger=logger,
             use_cache=use_cache,
+            db_session=db_session,
         )
 
         if not location_groups:
@@ -364,8 +397,14 @@ def reverse_geolocation_process(
             extraction_urls=extraction_urls,
             logger=logger,
             public=public,
+            feed=feed,
+            gtfs_dataset=gtfs_dataset,
+            db_session=db_session,
         )
 
+        # Commit the changes to the database
+        db_session.commit()
+        create_refresh_materialized_view_task()
         logger.info(
             "COMPLETED. Processed %s stops for stable ID %s with strategy. "
             "Retrieved %s locations.",
@@ -453,7 +492,6 @@ def reverse_geolocation(
             logger=logger,
             db_session=db_session,
         )
-        create_refresh_materialized_view_task()
         return cache_location_groups
 
 
@@ -508,5 +546,3 @@ def update_feed_location(
             gtfs_rt_feed.locations = feed_locations
     if feed_locations:
         feed.locations = feed_locations
-    # Commit the changes to the database
-    db_session.commit()
```
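
The commit also consolidates transaction handling: the per-helper `db_session.commit()` in `update_feed_location` and the materialized-view refresh inside `reverse_geolocation` are removed, and both now happen once at the end of `reverse_geolocation_process`, after all mutations (bounding box, feed locations, GeoJSON metadata) are staged on the shared session. A simplified, hedged sketch of that commit-once shape (the real entry point also handles tracing and error paths):

```python
from typing import Callable, Iterable

from sqlalchemy.orm import Session


def run_pipeline(
    db_session: Session,
    steps: Iterable[Callable[[Session], None]],
    on_success: Callable[[], None],
) -> None:
    """Stage every mutation on one session, commit once, then fire
    post-commit side effects (e.g. create_refresh_materialized_view_task)."""
    for step in steps:
        step(db_session)  # steps mutate ORM objects; none of them commit
    db_session.commit()   # all writes land in a single transaction
    on_success()          # enqueue the view refresh exactly once, post-commit
```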
